Commit e220812f2f

Andrew Kelley <andrew@ziglang.org>
2019-07-24 08:59:51
implement local variables in async functions
1 parent 19ee495
Changed files (3)
src
test
stage1
src/analyze.cpp
@@ -1911,11 +1911,32 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
         } else {
             param_name = buf_sprintf("arg%" ZIG_PRI_usize "", arg_i);
         }
-        ZigType *param_type = param_info[arg_i].type;
+        ZigType *param_type = param_info->type;
         field_names.append(buf_ptr(param_name));
         field_types.append(param_type);
     }
 
+    for (size_t alloca_i = 0; alloca_i < fn->alloca_gen_list.length; alloca_i += 1) {
+        IrInstructionAllocaGen *instruction = fn->alloca_gen_list.at(alloca_i);
+        ZigType *ptr_type = instruction->base.value.type;
+        assert(ptr_type->id == ZigTypeIdPointer);
+        ZigType *child_type = ptr_type->data.pointer.child_type;
+        if (!type_has_bits(child_type))
+            continue;
+        if (instruction->base.ref_count == 0)
+            continue;
+        if (instruction->base.value.special != ConstValSpecialRuntime) {
+            if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
+                    ConstValSpecialRuntime)
+            {
+                continue;
+            }
+        }
+        field_names.append(instruction->name_hint);
+        field_types.append(child_type);
+    }
+
+
     assert(field_names.length == field_types.length);
     frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name),
             field_names.items, field_types.items, field_names.length);
src/codegen.cpp
@@ -6174,6 +6174,7 @@ static void do_code_gen(CodeGen *g) {
         clear_debug_source_node(g);
 
         bool is_async = fn_is_async(fn_table_entry);
+        size_t async_var_index = coro_arg_start + (type_has_bits(fn_type_id->return_type) ? 1 : 0);
 
         if (want_sret || is_async) {
             g->cur_ret_ptr = LLVMGetParam(fn, 0);
@@ -6206,25 +6207,27 @@ static void do_code_gen(CodeGen *g) {
             g->cur_err_ret_trace_val_stack = nullptr;
         }
 
-        // allocate temporary stack data
-        for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {
-            IrInstructionAllocaGen *instruction = fn_table_entry->alloca_gen_list.at(alloca_i);
-            ZigType *ptr_type = instruction->base.value.type;
-            assert(ptr_type->id == ZigTypeIdPointer);
-            ZigType *child_type = ptr_type->data.pointer.child_type;
-            if (!type_has_bits(child_type))
-                continue;
-            if (instruction->base.ref_count == 0)
-                continue;
-            if (instruction->base.value.special != ConstValSpecialRuntime) {
-                if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
-                        ConstValSpecialRuntime)
-                {
+        if (!is_async) {
+            // allocate temporary stack data
+            for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {
+                IrInstructionAllocaGen *instruction = fn_table_entry->alloca_gen_list.at(alloca_i);
+                ZigType *ptr_type = instruction->base.value.type;
+                assert(ptr_type->id == ZigTypeIdPointer);
+                ZigType *child_type = ptr_type->data.pointer.child_type;
+                if (!type_has_bits(child_type))
                     continue;
+                if (instruction->base.ref_count == 0)
+                    continue;
+                if (instruction->base.value.special != ConstValSpecialRuntime) {
+                    if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
+                            ConstValSpecialRuntime)
+                    {
+                        continue;
+                    }
                 }
+                instruction->base.llvm_value = build_alloca(g, child_type, instruction->name_hint,
+                        get_ptr_align(g, ptr_type));
             }
-            instruction->base.llvm_value = build_alloca(g, child_type, instruction->name_hint,
-                    get_ptr_align(g, ptr_type));
         }
 
         ZigType *import = get_scope_import(&fn_table_entry->fndef_scope->base);
@@ -6263,9 +6266,9 @@ static void do_code_gen(CodeGen *g) {
                 fn_walk_var.data.vars.var = var;
                 iter_function_params_c_abi(g, fn_table_entry->type_entry, &fn_walk_var, var->src_arg_index);
             } else if (is_async) {
-                size_t ret_1_or_0 = type_has_bits(fn_type_id->return_type) ? 1 : 0;
-                var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr,
-                        coro_arg_start + ret_1_or_0 + var_i, "");
+                var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
+                        buf_ptr(&var->name));
+                async_var_index += 1;
                 if (var->decl_node) {
                     var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope),
                         buf_ptr(&var->name), import->data.structure.root_struct->di_file,
@@ -6299,6 +6302,29 @@ static void do_code_gen(CodeGen *g) {
             }
         }
 
+        if (is_async) {
+            for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {
+                IrInstructionAllocaGen *instruction = fn_table_entry->alloca_gen_list.at(alloca_i);
+                ZigType *ptr_type = instruction->base.value.type;
+                assert(ptr_type->id == ZigTypeIdPointer);
+                ZigType *child_type = ptr_type->data.pointer.child_type;
+                if (!type_has_bits(child_type))
+                    continue;
+                if (instruction->base.ref_count == 0)
+                    continue;
+                if (instruction->base.value.special != ConstValSpecialRuntime) {
+                    if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
+                            ConstValSpecialRuntime)
+                    {
+                        continue;
+                    }
+                }
+                instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
+                        instruction->name_hint);
+                async_var_index += 1;
+            }
+        }
+
         // finishing error return trace setup. we have to do this after all the allocas.
         if (have_err_ret_trace_stack) {
             ZigType *usize = g->builtin_types.entry_usize;
test/stage1/behavior/coroutines.zig
@@ -47,6 +47,36 @@ test "suspend at end of function" {
     };
     S.doTheTest();
 }
+
+test "local variable in async function" {
+    const S = struct {
+        var x: i32 = 0;
+
+        fn doTheTest() void {
+            expect(x == 0);
+            const p = async add(1, 2);
+            expect(x == 0);
+            resume p;
+            expect(x == 0);
+            resume p;
+            expect(x == 0);
+            resume p;
+            expect(x == 3);
+        }
+
+        fn add(a: i32, b: i32) void {
+            var accum: i32 = 0;
+            suspend;
+            accum += a;
+            suspend;
+            accum += b;
+            suspend;
+            x = accum;
+        }
+    };
+    S.doTheTest();
+}
+
 //test "coroutine suspend, resume" {
 //    seq('a');
 //    const p = try async<allocator> testAsyncSeq();