Commit b3f4182ca1

Andrew Kelley <superjoe30@gmail.com>
2018-07-27 04:26:00
coroutines have 3 more bits of atomic state
1 parent 2cbad36
Changed files (3)
src/all_types.hpp
@@ -3245,7 +3245,7 @@ static const size_t stack_trace_ptr_count = 30;
 #define RESULT_FIELD_NAME "result"
 #define ASYNC_ALLOC_FIELD_NAME "allocFn"
 #define ASYNC_FREE_FIELD_NAME "freeFn"
-#define AWAITER_HANDLE_FIELD_NAME "awaiter_handle"
+#define ATOMIC_STATE_FIELD_NAME "atomic_state"
 // these point to data belonging to the awaiter
 #define ERR_RET_TRACE_PTR_FIELD_NAME "err_ret_trace_ptr"
 #define RESULT_PTR_FIELD_NAME "result_ptr"
src/analyze.cpp
@@ -519,11 +519,11 @@ TypeTableEntry *get_promise_frame_type(CodeGen *g, TypeTableEntry *return_type)
         return return_type->promise_frame_parent;
     }
 
-    TypeTableEntry *awaiter_handle_type = get_optional_type(g, g->builtin_types.entry_promise);
+    TypeTableEntry *atomic_state_type = g->builtin_types.entry_usize;
     TypeTableEntry *result_ptr_type = get_pointer_to_type(g, return_type, false);
 
     ZigList<const char *> field_names = {};
-    field_names.append(AWAITER_HANDLE_FIELD_NAME);
+    field_names.append(ATOMIC_STATE_FIELD_NAME);
     field_names.append(RESULT_FIELD_NAME);
     field_names.append(RESULT_PTR_FIELD_NAME);
     if (g->have_err_ret_tracing) {
@@ -533,7 +533,7 @@ TypeTableEntry *get_promise_frame_type(CodeGen *g, TypeTableEntry *return_type)
     }
 
     ZigList<TypeTableEntry *> field_types = {};
-    field_types.append(awaiter_handle_type);
+    field_types.append(atomic_state_type);
     field_types.append(return_type);
     field_types.append(result_ptr_type);
     if (g->have_err_ret_tracing) {
@@ -6228,7 +6228,12 @@ uint32_t get_abi_alignment(CodeGen *g, TypeTableEntry *type_entry) {
     } else if (type_entry->id == TypeTableEntryIdOpaque) {
         return 1;
     } else {
-        return LLVMABIAlignmentOfType(g->target_data_ref, type_entry->type_ref);
+        uint32_t llvm_alignment = LLVMABIAlignmentOfType(g->target_data_ref, type_entry->type_ref);
+        // promises have at least alignment 8 so that we can have 3 extra bits when doing atomicrmw
+        if (type_entry->id == TypeTableEntryIdPromise && llvm_alignment < 8) {
+            return 8;
+        }
+        return llvm_alignment;
     }
 }
 
src/ir.cpp
@@ -3097,19 +3097,47 @@ static IrInstruction *ir_gen_async_return(IrBuilder *irb, Scope *scope, AstNode
         return return_inst;
     }
 
+    IrBasicBlock *canceled_block = ir_create_basic_block(irb, scope, "Canceled");
+    IrBasicBlock *not_canceled_block = ir_create_basic_block(irb, scope, "NotCanceled");
+    IrBasicBlock *suspended_block = ir_create_basic_block(irb, scope, "Suspended");
+    IrBasicBlock *not_suspended_block = ir_create_basic_block(irb, scope, "NotSuspended");
+
     ir_build_store_ptr(irb, scope, node, irb->exec->coro_result_field_ptr, return_value);
-    IrInstruction *promise_type_val = ir_build_const_type(irb, scope, node,
-            get_optional_type(irb->codegen, irb->codegen->builtin_types.entry_promise));
-    // TODO replace replacement_value with @intToPtr(?promise, 0x1) when it doesn't crash zig
-    IrInstruction *replacement_value = irb->exec->coro_handle;
-    IrInstruction *maybe_await_handle = ir_build_atomic_rmw(irb, scope, node,
-            promise_type_val, irb->exec->coro_awaiter_field_ptr, nullptr, replacement_value, nullptr,
-            AtomicRmwOp_xchg, AtomicOrderSeqCst);
-    ir_build_store_ptr(irb, scope, node, irb->exec->await_handle_var_ptr, maybe_await_handle);
-    IrInstruction *is_non_null = ir_build_test_nonnull(irb, scope, node, maybe_await_handle);
+    IrInstruction *usize_type_val = ir_build_const_type(irb, scope, node, irb->codegen->builtin_types.entry_usize);
+    IrInstruction *replacement_value = ir_build_const_usize(irb, scope, node, 0xa); // 0b1010
+    IrInstruction *prev_atomic_value = ir_build_atomic_rmw(irb, scope, node,
+            usize_type_val, irb->exec->coro_awaiter_field_ptr, nullptr, replacement_value, nullptr,
+            AtomicRmwOp_or, AtomicOrderSeqCst);
+
+    IrInstruction *zero = ir_build_const_usize(irb, scope, node, 0);
     IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node, false);
-    return ir_build_cond_br(irb, scope, node, is_non_null, irb->exec->coro_normal_final, irb->exec->coro_early_final,
-            is_comptime);
+    IrInstruction *is_canceled_mask = ir_build_const_usize(irb, scope, node, 0x1); // 0b001
+    IrInstruction *is_canceled_value = ir_build_bin_op(irb, scope, node, IrBinOpBinAnd, prev_atomic_value, is_canceled_mask, false);
+    IrInstruction *is_canceled_bool = ir_build_bin_op(irb, scope, node, IrBinOpCmpNotEq, is_canceled_value, zero, false);
+    ir_build_cond_br(irb, scope, node, is_canceled_bool, canceled_block, not_canceled_block, is_comptime);
+
+    ir_set_cursor_at_end_and_append_block(irb, canceled_block);
+    ir_mark_gen(ir_build_br(irb, scope, node, irb->exec->coro_final_cleanup_block, is_comptime));
+
+    ir_set_cursor_at_end_and_append_block(irb, not_canceled_block);
+    IrInstruction *inverted_ptr_mask = ir_build_const_usize(irb, scope, node, 0x7); // 0b111
+    IrInstruction *is_suspended_value = ir_build_bin_op(irb, scope, node, IrBinOpBinAnd, prev_atomic_value, inverted_ptr_mask, false);
+    IrInstruction *is_suspended_bool = ir_build_bin_op(irb, scope, node, IrBinOpCmpNotEq, is_suspended_value, zero, false);
+    ir_build_cond_br(irb, scope, node, is_suspended_bool, suspended_block, not_suspended_block, is_comptime);
+
+    ir_set_cursor_at_end_and_append_block(irb, suspended_block);
+    ir_build_unreachable(irb, scope, node);
+
+    ir_set_cursor_at_end_and_append_block(irb, not_suspended_block);
+    IrInstruction *ptr_mask = ir_build_un_op(irb, scope, node, IrUnOpBinNot, inverted_ptr_mask); // 0b111...000
+    IrInstruction *await_handle_addr = ir_build_bin_op(irb, scope, node, IrBinOpBinAnd, prev_atomic_value, ptr_mask, false);
+    IrInstruction *promise_type_val = ir_build_const_type(irb, scope, node, irb->codegen->builtin_types.entry_promise);
+    // if we ever add null checking safety to the inttoptr instruction, it needs to be disabled here
+    IrInstruction *await_handle = ir_build_int_to_ptr(irb, scope, node, promise_type_val, await_handle_addr);
+    ir_build_store_ptr(irb, scope, node, irb->exec->await_handle_var_ptr, await_handle);
+    IrInstruction *is_non_null = ir_build_bin_op(irb, scope, node, IrBinOpCmpNotEq, await_handle_addr, zero, false);
+    return ir_build_cond_br(irb, scope, node, is_non_null, irb->exec->coro_normal_final,
+            irb->exec->coro_early_final, is_comptime);
     // the above blocks are rendered by ir_gen after the rest of codegen
 }
 
@@ -6708,9 +6736,9 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast
         ir_build_store_ptr(irb, parent_scope, node, err_ret_trace_ptr_field_ptr, err_ret_trace_ptr);
     }
 
-    Buf *awaiter_handle_field_name = buf_create_from_str(AWAITER_HANDLE_FIELD_NAME);
-    IrInstruction *awaiter_field_ptr = ir_build_field_ptr(irb, parent_scope, node, coro_promise_ptr,
-            awaiter_handle_field_name);
+    Buf *atomic_state_field_name = buf_create_from_str(ATOMIC_STATE_FIELD_NAME);
+    IrInstruction *atomic_state_ptr = ir_build_field_ptr(irb, parent_scope, node, coro_promise_ptr,
+            atomic_state_field_name);
 
     IrInstruction *const_bool_false = ir_build_const_bool(irb, parent_scope, node, false);
     VariableTableEntry *result_var = ir_create_var(irb, node, parent_scope, nullptr,
@@ -6723,12 +6751,16 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast
     IrInstruction *my_result_var_ptr = ir_build_var_ptr(irb, parent_scope, node, result_var);
     ir_build_store_ptr(irb, parent_scope, node, result_ptr_field_ptr, my_result_var_ptr);
     IrInstruction *save_token = ir_build_coro_save(irb, parent_scope, node, irb->exec->coro_handle);
-    IrInstruction *promise_type_val = ir_build_const_type(irb, parent_scope, node,
-            get_optional_type(irb->codegen, irb->codegen->builtin_types.entry_promise));
-    IrInstruction *maybe_await_handle = ir_build_atomic_rmw(irb, parent_scope, node, 
-            promise_type_val, awaiter_field_ptr, nullptr, irb->exec->coro_handle, nullptr,
-            AtomicRmwOp_xchg, AtomicOrderSeqCst);
-    IrInstruction *is_non_null = ir_build_test_nonnull(irb, parent_scope, node, maybe_await_handle);
+    IrInstruction *usize_type_val = ir_build_const_type(irb, parent_scope, node, irb->codegen->builtin_types.entry_usize);
+    IrInstruction *coro_handle_addr = ir_build_ptr_to_int(irb, parent_scope, node, irb->exec->coro_handle);
+    IrInstruction *prev_atomic_value = ir_build_atomic_rmw(irb, parent_scope, node, 
+            usize_type_val, atomic_state_ptr, nullptr, coro_handle_addr, nullptr,
+            AtomicRmwOp_or, AtomicOrderSeqCst);
+    IrInstruction *zero = ir_build_const_usize(irb, parent_scope, node, 0);
+    IrInstruction *inverted_ptr_mask = ir_build_const_usize(irb, parent_scope, node, 0x7); // 0b111
+    IrInstruction *ptr_mask = ir_build_un_op(irb, parent_scope, node, IrUnOpBinNot, inverted_ptr_mask); // 0b111...000
+    IrInstruction *await_handle_addr = ir_build_bin_op(irb, parent_scope, node, IrBinOpBinAnd, prev_atomic_value, ptr_mask, false);
+    IrInstruction *is_non_null = ir_build_bin_op(irb, parent_scope, node, IrBinOpCmpNotEq, await_handle_addr, zero, false);
     IrBasicBlock *yes_suspend_block = ir_create_basic_block(irb, parent_scope, "YesSuspend");
     IrBasicBlock *no_suspend_block = ir_create_basic_block(irb, parent_scope, "NoSuspend");
     IrBasicBlock *merge_block = ir_create_basic_block(irb, parent_scope, "MergeSuspend");
@@ -7087,10 +7119,11 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         IrInstruction *coro_mem_ptr = ir_build_ptr_cast(irb, coro_scope, node, u8_ptr_type, maybe_coro_mem_ptr);
         irb->exec->coro_handle = ir_build_coro_begin(irb, coro_scope, node, coro_id, coro_mem_ptr);
 
-        Buf *awaiter_handle_field_name = buf_create_from_str(AWAITER_HANDLE_FIELD_NAME);
+        Buf *atomic_state_field_name = buf_create_from_str(ATOMIC_STATE_FIELD_NAME);
         irb->exec->coro_awaiter_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr,
-                awaiter_handle_field_name);
-        ir_build_store_ptr(irb, scope, node, irb->exec->coro_awaiter_field_ptr, null_value);
+                atomic_state_field_name);
+        IrInstruction *zero = ir_build_const_usize(irb, scope, node, 0);
+        ir_build_store_ptr(irb, scope, node, irb->exec->coro_awaiter_field_ptr, zero);
         Buf *result_field_name = buf_create_from_str(RESULT_FIELD_NAME);
         irb->exec->coro_result_field_ptr = ir_build_field_ptr(irb, scope, node, coro_promise_ptr, result_field_name);
         result_ptr_field_name = buf_create_from_str(RESULT_PTR_FIELD_NAME);
@@ -7108,7 +7141,6 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
             // coordinate with builtin.zig
             Buf *index_name = buf_create_from_str("index");
             IrInstruction *index_ptr = ir_build_field_ptr(irb, scope, node, err_ret_trace_ptr, index_name);
-            IrInstruction *zero = ir_build_const_usize(irb, scope, node, 0);
             ir_build_store_ptr(irb, scope, node, index_ptr, zero);
 
             Buf *instruction_addresses_name = buf_create_from_str("instruction_addresses");