Commit 84e952c230

Andrew Kelley <superjoe30@gmail.com>
2018-03-10 07:12:22
fix await multithreaded data race
coro return was reading from a value that coro await was writing to. that wasn't how it was designed to work, it was an implementation mistake. this commit also has some work-in-progress code for fixing error return traces across suspend points.
1 parent 3b3649b
src/all_types.hpp
@@ -61,6 +61,7 @@ struct IrExecutable {
     IrInstruction *coro_handle;
     IrInstruction *coro_awaiter_field_ptr; // this one is shared and in the promise
     IrInstruction *coro_result_ptr_field_ptr;
+    IrInstruction *coro_result_field_ptr;
     IrInstruction *await_handle_var_ptr; // this one is where we put the one we extracted from the promise
     IrBasicBlock *coro_early_final;
     IrBasicBlock *coro_normal_final;
@@ -1281,7 +1282,7 @@ struct FnTableEntry {
     bool is_cold;
 
     ZigList<FnExport> export_list;
-    bool calls_errorable_function;
+    bool calls_or_awaits_errorable_fn;
 };
 
 uint32_t fn_table_entry_hash(FnTableEntry*);
@@ -2038,6 +2039,7 @@ enum IrInstructionId {
     IrInstructionIdCoroAllocHelper,
     IrInstructionIdAtomicRmw,
     IrInstructionIdPromiseResultType,
+    IrInstructionIdAwaitBookkeeping,
 };
 
 struct IrInstruction {
@@ -2985,6 +2987,12 @@ struct IrInstructionPromiseResultType {
     IrInstruction *promise_type;
 };
 
+struct IrInstructionAwaitBookkeeping {
+    IrInstruction base;
+
+    IrInstruction *promise_result_type;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/analyze.cpp
@@ -5856,9 +5856,11 @@ uint32_t get_coro_frame_align_bytes(CodeGen *g) {
     return g->pointer_size_bytes * 2;
 }
 
+bool type_can_fail(TypeTableEntry *type_entry) {
+    return type_entry->id == TypeTableEntryIdErrorUnion || type_entry->id == TypeTableEntryIdErrorSet;
+}
+
 bool fn_type_can_fail(FnTypeId *fn_type_id) {
-    TypeTableEntry *return_type = fn_type_id->return_type;
-    return return_type->id == TypeTableEntryIdErrorUnion || return_type->id == TypeTableEntryIdErrorSet ||
-        fn_type_id->cc == CallingConventionAsync;
+    return type_can_fail(fn_type_id->return_type) || fn_type_id->cc == CallingConventionAsync;
 }
 
src/analyze.hpp
@@ -195,6 +195,7 @@ TypeTableEntry *get_auto_err_set_type(CodeGen *g, FnTableEntry *fn_entry);
 
 uint32_t get_coro_frame_align_bytes(CodeGen *g);
 bool fn_type_can_fail(FnTypeId *fn_type_id);
+bool type_can_fail(TypeTableEntry *type_entry);
 bool fn_eval_cacheable(Scope *scope);
 
 #endif
src/codegen.cpp
@@ -4251,6 +4251,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdExport:
         case IrInstructionIdErrorUnion:
         case IrInstructionIdPromiseResultType:
+        case IrInstructionIdAwaitBookkeeping:
             zig_unreachable();
 
         case IrInstructionIdReturn:
@@ -5279,7 +5280,7 @@ static void do_code_gen(CodeGen *g) {
         uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry);
         if (err_ret_trace_arg_index != UINT32_MAX) {
             g->cur_err_ret_trace_val = LLVMGetParam(fn, err_ret_trace_arg_index);
-        } else if (g->have_err_ret_tracing && fn_table_entry->calls_errorable_function) {
+        } else if (g->have_err_ret_tracing && fn_table_entry->calls_or_awaits_errorable_fn) {
             // TODO call graph analysis to find out what this number needs to be for every function
             static const size_t stack_trace_ptr_count = 30;
 
src/ir.cpp
@@ -707,6 +707,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionPromiseResultTyp
     return IrInstructionIdPromiseResultType;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionAwaitBookkeeping *) {
+    return IrInstructionIdAwaitBookkeeping;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -2656,6 +2660,17 @@ static IrInstruction *ir_build_promise_result_type(IrBuilder *irb, Scope *scope,
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_await_bookkeeping(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *promise_result_type)
+{
+    IrInstructionAwaitBookkeeping *instruction = ir_build_instruction<IrInstructionAwaitBookkeeping>(irb, scope, source_node);
+    instruction->promise_result_type = promise_result_type;
+
+    ir_ref_instruction(promise_result_type, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -2734,13 +2749,16 @@ static IrInstruction *ir_gen_async_return(IrBuilder *irb, Scope *scope, AstNode
     FnTableEntry *fn_entry = exec_fn_entry(irb->exec);
     bool is_async = fn_entry != nullptr && fn_entry->type_entry->data.fn.fn_type_id.cc == CallingConventionAsync;
     if (!is_async) {
+        //if (irb->codegen->have_err_ret_tracing) {
+        //    IrInstruction *stack_trace_ptr = ir_build_error_return_trace_nonnull(irb, scope, node);
+        //    ir_build_save_err_ret_addr(irb, scope, node, stack_trace_ptr);
+        //}
         IrInstruction *return_inst = ir_build_return(irb, scope, node, return_value);
         return_inst->is_gen = is_generated_code;
         return return_inst;
     }
 
-    IrInstruction *result_ptr = ir_build_load_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr);
-    ir_build_store_ptr(irb, scope, node, result_ptr, return_value);
+    ir_build_store_ptr(irb, scope, node, irb->exec->coro_result_field_ptr, return_value);
     IrInstruction *promise_type_val = ir_build_const_type(irb, scope, node,
             get_maybe_type(irb->codegen, irb->codegen->builtin_types.entry_promise));
     // TODO replace replacement_value with @intToPtr(?promise, 0x1) when it doesn't crash zig
@@ -2756,6 +2774,22 @@ static IrInstruction *ir_gen_async_return(IrBuilder *irb, Scope *scope, AstNode
     // the above blocks are rendered by ir_gen after the rest of codegen
 }
 
+//static void ir_gen_save_err_ret_addr(IrBuilder *irb, Scope *scope, AstNode *node, bool is_async) {
+//    if (!irb->codegen->have_err_ret_tracing)
+//        return;
+//
+//    if (is_async) {
+//        IrInstruction *err_ret_addr_ptr = ir_build_load_ptr(irb, scope, node, irb->exec->coro_err_ret_addr_ptr);
+//        IrInstruction *return_address_ptr = ir_build_return_address(irb, scope, node);
+//        IrInstruction *return_address_usize = ir_build_ptr_to_int(irb, scope, node, return_address_ptr);
+//        ir_build_store_ptr(irb, scope, node, err_ret_addr_ptr, return_address_usize);
+//        return;
+//    }
+//
+//    IrInstruction *stack_trace_ptr = ir_build_error_return_trace_nonnull(irb, scope, node);
+//    ir_build_save_err_ret_addr(irb, scope, node, stack_trace_ptr);
+//}
+
 static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, LVal lval) {
     assert(node->type == NodeTypeReturnExpr);
 
@@ -2791,9 +2825,13 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
 
                 size_t defer_counts[2];
                 ir_count_defers(irb, scope, outer_scope, defer_counts);
-                if (defer_counts[ReturnKindError] > 0) {
+                bool have_err_defers = defer_counts[ReturnKindError] > 0;
+                if (have_err_defers || irb->codegen->have_err_ret_tracing) {
                     IrBasicBlock *err_block = ir_create_basic_block(irb, scope, "ErrRetErr");
                     IrBasicBlock *ok_block = ir_create_basic_block(irb, scope, "ErrRetOk");
+                    if (!have_err_defers) {
+                        ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                    }
 
                     IrInstruction *is_err = ir_build_test_err(irb, scope, node, return_value);
 
@@ -2808,11 +2846,16 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                     IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt");
 
                     ir_set_cursor_at_end_and_append_block(irb, err_block);
-                    ir_gen_defers_for_block(irb, scope, outer_scope, true);
+                    if (have_err_defers) {
+                        ir_gen_defers_for_block(irb, scope, outer_scope, true);
+                    }
+                    //ir_gen_save_err_ret_addr(irb, scope, node, is_async);
                     ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                     ir_set_cursor_at_end_and_append_block(irb, ok_block);
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                    if (have_err_defers) {
+                        ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                    }
                     ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                     ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
@@ -2834,7 +2877,12 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
 
                 IrBasicBlock *return_block = ir_create_basic_block(irb, scope, "ErrRetReturn");
                 IrBasicBlock *continue_block = ir_create_basic_block(irb, scope, "ErrRetContinue");
-                IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node, ir_should_inline(irb->exec, scope));
+                IrInstruction *is_comptime;
+                if (ir_should_inline(irb->exec, scope)) {
+                    is_comptime = ir_build_const_bool(irb, scope, node, true);
+                } else {
+                    is_comptime = ir_build_test_comptime(irb, scope, node, is_err_val);
+                }
                 ir_mark_gen(ir_build_cond_br(irb, scope, node, is_err_val, return_block, continue_block, is_comptime));
 
                 ir_set_cursor_at_end_and_append_block(irb, return_block);
@@ -6002,6 +6050,7 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast
     IrInstruction *undefined_value = ir_build_const_undefined(irb, parent_scope, node);
     IrInstruction *target_promise_type = ir_build_typeof(irb, parent_scope, node, target_inst);
     IrInstruction *promise_result_type = ir_build_promise_result_type(irb, parent_scope, node, target_promise_type);
+    ir_build_await_bookkeeping(irb, parent_scope, node, promise_result_type);
     ir_build_var_decl(irb, parent_scope, node, result_var, promise_result_type, nullptr, undefined_value);
     IrInstruction *my_result_var_ptr = ir_build_var_ptr(irb, parent_scope, node, result_var, false, false);
     ir_build_store_ptr(irb, parent_scope, node, result_ptr_field_ptr, my_result_var_ptr);
@@ -6271,7 +6320,6 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
     IrInstruction *coro_id;
     IrInstruction *u8_ptr_type;
     IrInstruction *const_bool_false;
-    IrInstruction *coro_result_field_ptr;
     TypeTableEntry *return_type;
     Buf *result_ptr_field_name;
     VariableTableEntry *coro_size_var;
@@ -6325,10 +6373,10 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         irb->exec->coro_awaiter_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr,
                 awaiter_handle_field_name);
         Buf *result_field_name = buf_create_from_str(RESULT_FIELD_NAME);
-        coro_result_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_field_name);
+        irb->exec->coro_result_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_field_name);
         result_ptr_field_name = buf_create_from_str(RESULT_PTR_FIELD_NAME);
         irb->exec->coro_result_ptr_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_ptr_field_name);
-        ir_build_store_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr, coro_result_field_ptr);
+        ir_build_store_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr, irb->exec->coro_result_field_ptr);
 
 
         irb->exec->coro_early_final = ir_create_basic_block(irb, scope, "CoroEarlyFinal");
@@ -6368,14 +6416,11 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         ir_build_unreachable(irb, scope, node);
 
         ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_normal_final);
-        ir_build_br(irb, scope, node, check_free_block, const_bool_false);
-
-        ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_final_cleanup_block);
         if (type_has_bits(return_type)) {
             IrInstruction *result_ptr = ir_build_load_ptr(irb, scope, node, irb->exec->coro_result_ptr_field_ptr);
             IrInstruction *result_ptr_as_u8_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type, result_ptr);
             IrInstruction *return_value_ptr_as_u8_ptr = ir_build_ptr_cast(irb, scope, node, u8_ptr_type,
-                    coro_result_field_ptr);
+                    irb->exec->coro_result_field_ptr);
             IrInstruction *return_type_inst = ir_build_const_type(irb, scope, node,
                     fn_entry->type_entry->data.fn.fn_type_id.return_type);
             IrInstruction *size_of_ret_val = ir_build_size_of(irb, scope, node, return_type_inst);
@@ -6383,6 +6428,9 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         }
         ir_build_br(irb, scope, node, check_free_block, const_bool_false);
 
+        ir_set_cursor_at_end_and_append_block(irb, irb->exec->coro_final_cleanup_block);
+        ir_build_br(irb, scope, node, check_free_block, const_bool_false);
+
         ir_set_cursor_at_end_and_append_block(irb, check_free_block);
         IrBasicBlock **incoming_blocks = allocate<IrBasicBlock *>(2);
         IrInstruction **incoming_values = allocate<IrInstruction *>(2);
@@ -11405,7 +11453,7 @@ static TypeTableEntry *ir_analyze_instruction_error_return_trace(IrAnalyze *ira,
     FnTableEntry *fn_entry = exec_fn_entry(ira->new_irb.exec);
     TypeTableEntry *ptr_to_stack_trace_type = get_ptr_to_stack_trace_type(ira->codegen);
     TypeTableEntry *nullable_type = get_maybe_type(ira->codegen, ptr_to_stack_trace_type);
-    if (fn_entry == nullptr || !fn_entry->calls_errorable_function || !ira->codegen->have_err_ret_tracing) {
+    if (fn_entry == nullptr || !fn_entry->calls_or_awaits_errorable_fn || !ira->codegen->have_err_ret_tracing) {
         ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
         out_val->data.x_maybe = nullptr;
         return nullable_type;
@@ -12085,7 +12133,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
 
         TypeTableEntry *return_type = impl_fn->type_entry->data.fn.fn_type_id.return_type;
         if (fn_type_can_fail(&impl_fn->type_entry->data.fn.fn_type_id)) {
-            parent_fn_entry->calls_errorable_function = true;
+            parent_fn_entry->calls_or_awaits_errorable_fn = true;
         }
 
         size_t impl_param_count = impl_fn->type_entry->data.fn.fn_type_id.param_count;
@@ -12111,7 +12159,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
     assert(fn_type_id->return_type != nullptr);
     assert(parent_fn_entry != nullptr);
     if (fn_type_can_fail(fn_type_id)) {
-        parent_fn_entry->calls_errorable_function = true;
+        parent_fn_entry->calls_or_awaits_errorable_fn = true;
     }
 
 
@@ -17655,6 +17703,22 @@ static TypeTableEntry *ir_analyze_instruction_promise_result_type(IrAnalyze *ira
     return ira->codegen->builtin_types.entry_type;
 }
 
+static TypeTableEntry *ir_analyze_instruction_await_bookkeeping(IrAnalyze *ira, IrInstructionAwaitBookkeeping *instruction) {
+    TypeTableEntry *promise_result_type = ir_resolve_type(ira, instruction->promise_result_type->other);
+    if (type_is_invalid(promise_result_type))
+        return ira->codegen->builtin_types.entry_invalid;
+
+    FnTableEntry *fn_entry = exec_fn_entry(ira->new_irb.exec);
+    assert(fn_entry != nullptr);
+
+    if (type_can_fail(promise_result_type)) {
+        fn_entry->calls_or_awaits_errorable_fn = true;
+    }
+
+    ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+    out_val->type = ira->codegen->builtin_types.entry_void;
+    return out_val->type;
+}
 
 static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
@@ -17672,6 +17736,7 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
         case IrInstructionIdErrWrapPayload:
         case IrInstructionIdCast:
             zig_unreachable();
+
         case IrInstructionIdReturn:
             return ir_analyze_instruction_return(ira, (IrInstructionReturn *)instruction);
         case IrInstructionIdConst:
@@ -17890,6 +17955,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_atomic_rmw(ira, (IrInstructionAtomicRmw *)instruction);
         case IrInstructionIdPromiseResultType:
             return ir_analyze_instruction_promise_result_type(ira, (IrInstructionPromiseResultType *)instruction);
+        case IrInstructionIdAwaitBookkeeping:
+            return ir_analyze_instruction_await_bookkeeping(ira, (IrInstructionAwaitBookkeeping *)instruction);
     }
     zig_unreachable();
 }
@@ -18014,6 +18081,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdCoroResume:
         case IrInstructionIdCoroSave:
         case IrInstructionIdCoroAllocHelper:
+        case IrInstructionIdAwaitBookkeeping:
             return true;
 
         case IrInstructionIdPhi:
src/ir_print.cpp
@@ -1155,6 +1155,12 @@ static void ir_print_atomic_rmw(IrPrint *irp, IrInstructionAtomicRmw *instructio
     fprintf(irp->f, ")");
 }
 
+static void ir_print_await_bookkeeping(IrPrint *irp, IrInstructionAwaitBookkeeping *instruction) {
+    fprintf(irp->f, "@awaitBookkeeping(");
+    ir_print_other_instruction(irp, instruction->promise_result_type);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -1523,6 +1529,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdPromiseResultType:
             ir_print_promise_result_type(irp, (IrInstructionPromiseResultType *)instruction);
             break;
+        case IrInstructionIdAwaitBookkeeping:
+            ir_print_await_bookkeeping(irp, (IrInstructionAwaitBookkeeping *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }