Commit 8c39cdc89f

Andrew Kelley <superjoe30@gmail.com>
2018-07-04 03:36:16
fix await on early return when return type is struct
previously, await on an early return would try to access the destroyed coroutine frame; now it copies the result into a temporary variable before destroying the coroutine frame
1 parent 1d18688
src/ir.cpp
@@ -6674,7 +6674,10 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast
     }
     Buf *result_field_name = buf_create_from_str(RESULT_FIELD_NAME);
     IrInstruction *promise_result_ptr = ir_build_field_ptr(irb, parent_scope, node, coro_promise_ptr, result_field_name);
+    // If the type of the result handle_is_ptr then this does not actually perform a load. But we need it to,
+    // because we're about to destroy the memory. So we store it into our result variable.
     IrInstruction *no_suspend_result = ir_build_load_ptr(irb, parent_scope, node, promise_result_ptr);
+    ir_build_store_ptr(irb, parent_scope, node, my_result_var_ptr, no_suspend_result);
     ir_build_cancel(irb, parent_scope, node, target_inst);
     ir_build_br(irb, parent_scope, node, merge_block, const_bool_false);
 
@@ -6696,17 +6699,10 @@ static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *parent_scope, Ast
     ir_mark_gen(ir_build_br(irb, parent_scope, node, irb->exec->coro_final_cleanup_block, const_bool_false));
 
     ir_set_cursor_at_end_and_append_block(irb, resume_block);
-    IrInstruction *yes_suspend_result = ir_build_load_ptr(irb, parent_scope, node, my_result_var_ptr);
     ir_build_br(irb, parent_scope, node, merge_block, const_bool_false);
 
     ir_set_cursor_at_end_and_append_block(irb, merge_block);
-    IrBasicBlock **incoming_blocks = allocate<IrBasicBlock *>(2);
-    IrInstruction **incoming_values = allocate<IrInstruction *>(2);
-    incoming_blocks[0] = resume_block;
-    incoming_values[0] = yes_suspend_result;
-    incoming_blocks[1] = no_suspend_block;
-    incoming_values[1] = no_suspend_result;
-    return ir_build_phi(irb, parent_scope, node, 2, incoming_blocks, incoming_values);
+    return ir_build_load_ptr(irb, parent_scope, node, my_result_var_ptr);
 }
 
 static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNode *node) {
test/cases/coroutine_await_struct.zig
@@ -0,0 +1,47 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const assert = std.debug.assert;
+
+const Foo = struct {
+    x: i32,
+};
+
+var await_a_promise: promise = undefined;
+var await_final_result = Foo{ .x = 0 };
+
+test "coroutine await struct" {
+    var da = std.heap.DirectAllocator.init();
+    defer da.deinit();
+
+    await_seq('a');
+    const p = async<&da.allocator> await_amain() catch unreachable;
+    await_seq('f');
+    resume await_a_promise;
+    await_seq('i');
+    assert(await_final_result.x == 1234);
+    assert(std.mem.eql(u8, await_points, "abcdefghi"));
+}
+async fn await_amain() void {
+    await_seq('b');
+    const p = async await_another() catch unreachable;
+    await_seq('e');
+    await_final_result = await p;
+    await_seq('h');
+}
+async fn await_another() Foo {
+    await_seq('c');
+    suspend |p| {
+        await_seq('d');
+        await_a_promise = p;
+    }
+    await_seq('g');
+    return Foo{ .x = 1234 };
+}
+
+var await_points = []u8{0} ** "abcdefghi".len;
+var await_seq_index: usize = 0;
+
+fn await_seq(c: u8) void {
+    await_points[await_seq_index] = c;
+    await_seq_index += 1;
+}
test/cases/coroutines.zig
@@ -116,14 +116,14 @@ test "coroutine await early return" {
     defer da.deinit();
 
     early_seq('a');
-    const p = async<&da.allocator> early_amain() catch unreachable;
+    const p = async<&da.allocator> early_amain() catch @panic("out of memory");
     early_seq('f');
     assert(early_final_result == 1234);
     assert(std.mem.eql(u8, early_points, "abcdef"));
 }
 async fn early_amain() void {
     early_seq('b');
-    const p = async early_another() catch unreachable;
+    const p = async early_another() catch @panic("out of memory");
     early_seq('d');
     early_final_result = await p;
     early_seq('e');
test/behavior.zig
@@ -18,6 +18,7 @@ comptime {
     _ = @import("cases/cast.zig");
     _ = @import("cases/const_slice_child.zig");
     _ = @import("cases/coroutines.zig");
+    _ = @import("cases/coroutine_await_struct.zig");
     _ = @import("cases/defer.zig");
     _ = @import("cases/enum.zig");
     _ = @import("cases/enum_with_members.zig");