Commit fcadeb50c0

Andrew Kelley <andrew@ziglang.org>
2019-07-22 20:36:14
fix multiple coroutines existing clobbering each other
1 parent 650e07e
src/analyze.cpp
@@ -1865,7 +1865,8 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) {
 }
 
 static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
-    assert(frame_type->data.frame.locals_struct == nullptr);
+    if (frame_type->data.frame.locals_struct != nullptr)
+        return ErrorNone;
 
     ZigFn *fn = frame_type->data.frame.fn;
     switch (fn->anal_state) {
@@ -3824,6 +3825,15 @@ static void analyze_fn_ir(CodeGen *g, ZigFn *fn_table_entry, AstNode *return_typ
     }
 
     fn_table_entry->anal_state = FnAnalStateComplete;
+
+    if (fn_table_entry->resume_blocks.length != 0) {
+        ZigType *frame_type = get_coro_frame_type(g, fn_table_entry);
+        Error err;
+        if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
+            fn_table_entry->anal_state = FnAnalStateInvalid;
+            return;
+        }
+    }
 }
 
 static void analyze_fn_body(CodeGen *g, ZigFn *fn_table_entry) {
@@ -7050,18 +7060,12 @@ static void resolve_llvm_types_array(CodeGen *g, ZigType *type) {
             debug_align_in_bits, get_llvm_di_type(g, elem_type), (int)type->data.array.len);
 }
 
-void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
-    if (fn_type->llvm_di_type != nullptr) {
-        if (fn != nullptr) {
-            fn->raw_type_ref = fn_type->data.fn.raw_type_ref;
-            fn->raw_di_type = fn_type->data.fn.raw_di_type;
-        }
-        return;
-    }
+static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) {
+    if (fn_type->llvm_di_type != nullptr) return;
 
     FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
     bool first_arg_return = want_first_arg_sret(g, fn_type_id);
-    bool is_async = fn_type_id->cc == CallingConventionAsync || (fn != nullptr && fn->resume_blocks.length != 0);
+    bool is_async = fn_type_id->cc == CallingConventionAsync;
     bool is_c_abi = fn_type_id->cc == CallingConventionC;
     bool prefix_arg_error_return_trace = g->have_err_ret_tracing && fn_type_can_fail(fn_type_id);
     // +1 for maybe making the first argument the return value
@@ -7100,7 +7104,11 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
     if (is_async) {
         fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(1);
 
-        ZigType *frame_type = (fn == nullptr) ? g->builtin_types.entry_frame_header : get_coro_frame_type(g, fn);
+        ZigType *frame_type = g->builtin_types.entry_frame_header;
+        Error err;
+        if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
+            zig_unreachable();
+        }
         ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
         gen_param_types.append(get_llvm_type(g, ptr_type));
         param_di_types.append(get_llvm_di_type(g, ptr_type));
@@ -7150,12 +7158,7 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
     for (size_t i = 0; i < gen_param_types.length; i += 1) {
         assert(gen_param_types.items[i] != nullptr);
     }
-    if (fn != nullptr) {
-        fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type),
-                gen_param_types.items, (unsigned int)gen_param_types.length, fn_type_id->is_var_args);
-        fn->raw_di_type = ZigLLVMCreateSubroutineType(g->dbuilder, param_di_types.items, (int)param_di_types.length, 0);
-        return;
-    }
+
     fn_type->data.fn.raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type),
             gen_param_types.items, (unsigned int)gen_param_types.length, fn_type_id->is_var_args);
     fn_type->llvm_type = LLVMPointerType(fn_type->data.fn.raw_type_ref, 0);
@@ -7165,6 +7168,35 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
             LLVMABIAlignmentOfType(g->target_data_ref, fn_type->llvm_type), "");
 }
 
+void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn) {
+    if (fn->raw_di_type != nullptr) return;
+
+    ZigType *fn_type = fn->type_entry;
+    FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
+    bool cc_async = fn_type_id->cc == CallingConventionAsync;
+    bool inferred_async = fn->resume_blocks.length != 0;
+    bool is_async = cc_async || inferred_async;
+    if (!is_async) {
+        resolve_llvm_types_fn_type(g, fn_type);
+        fn->raw_type_ref = fn_type->data.fn.raw_type_ref;
+        fn->raw_di_type = fn_type->data.fn.raw_di_type;
+        return;
+    }
+
+    ZigType *gen_return_type = g->builtin_types.entry_usize;
+    ZigList<ZigLLVMDIType *> param_di_types = {};
+    // first "parameter" is return value
+    param_di_types.append(get_llvm_di_type(g, gen_return_type));
+
+    ZigType *frame_type = get_coro_frame_type(g, fn);
+    ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
+    LLVMTypeRef gen_param_type = get_llvm_type(g, ptr_type);
+    param_di_types.append(get_llvm_di_type(g, ptr_type));
+
+    fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), &gen_param_type, 1, false);
+    fn->raw_di_type = ZigLLVMCreateSubroutineType(g->dbuilder, param_di_types.items, (int)param_di_types.length, 0);
+}
+
 static void resolve_llvm_types_anyerror(CodeGen *g) {
     ZigType *entry = g->builtin_types.entry_global_error_set;
     entry->llvm_type = get_llvm_type(g, g->err_tag_type);
@@ -7241,7 +7273,7 @@ static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_r
         case ZigTypeIdArray:
             return resolve_llvm_types_array(g, type);
         case ZigTypeIdFn:
-            return resolve_llvm_types_fn(g, type, nullptr);
+            return resolve_llvm_types_fn_type(g, type);
         case ZigTypeIdErrorSet: {
             if (type->llvm_di_type != nullptr) return;
 
src/analyze.hpp
@@ -247,6 +247,6 @@ void src_assert(bool ok, AstNode *source_node);
 bool is_container(ZigType *type_entry);
 ConstExprValue *analyze_const_value(CodeGen *g, Scope *scope, AstNode *node, ZigType *type_entry, Buf *type_name);
 
-void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn);
+void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn);
 
 #endif
src/codegen.cpp
@@ -371,10 +371,12 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) {
         symbol_name = buf_sprintf("\x01_%s", buf_ptr(symbol_name));
     }
 
+    bool is_async = fn_table_entry->resume_blocks.length != 0 || cc == CallingConventionAsync;
+
 
     ZigType *fn_type = fn_table_entry->type_entry;
     // Make the raw_type_ref populated
-    resolve_llvm_types_fn(g, fn_type, fn_table_entry);
+    resolve_llvm_types_fn(g, fn_table_entry);
     LLVMTypeRef fn_llvm_type = fn_table_entry->raw_type_ref;
     if (fn_table_entry->body_node == nullptr) {
         LLVMValueRef existing_llvm_fn = LLVMGetNamedFunction(g->module, buf_ptr(symbol_name));
@@ -397,7 +399,7 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) {
                 assert(entry->value->id == TldIdFn);
                 TldFn *tld_fn = reinterpret_cast<TldFn *>(entry->value);
                 // Make the raw_type_ref populated
-                resolve_llvm_types_fn(g, tld_fn->fn_entry->type_entry, tld_fn->fn_entry);
+                resolve_llvm_types_fn(g, tld_fn->fn_entry);
                 tld_fn->fn_entry->llvm_value = LLVMAddFunction(g->module, buf_ptr(symbol_name),
                         tld_fn->fn_entry->raw_type_ref);
                 fn_table_entry->llvm_value = LLVMConstBitCast(tld_fn->fn_entry->llvm_value,
@@ -517,18 +519,22 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) {
         init_gen_i = 1;
     }
 
-    // set parameter attributes
-    FnWalk fn_walk = {};
-    fn_walk.id = FnWalkIdAttrs;
-    fn_walk.data.attrs.fn = fn_table_entry;
-    fn_walk.data.attrs.gen_i = init_gen_i;
-    walk_function_params(g, fn_type, &fn_walk);
+    if (is_async) {
+        addLLVMArgAttr(fn_table_entry->llvm_value, 0, "nonnull");
+    } else {
+        // set parameter attributes
+        FnWalk fn_walk = {};
+        fn_walk.id = FnWalkIdAttrs;
+        fn_walk.data.attrs.fn = fn_table_entry;
+        fn_walk.data.attrs.gen_i = init_gen_i;
+        walk_function_params(g, fn_type, &fn_walk);
 
-    uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry);
-    if (err_ret_trace_arg_index != UINT32_MAX) {
-        // Error return trace memory is in the stack, which is impossible to be at address 0
-        // on any architecture.
-        addLLVMArgAttr(fn_table_entry->llvm_value, (unsigned)err_ret_trace_arg_index, "nonnull");
+        uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry);
+        if (err_ret_trace_arg_index != UINT32_MAX) {
+            // Error return trace memory is in the stack, which is impossible to be at address 0
+            // on any architecture.
+            addLLVMArgAttr(fn_table_entry->llvm_value, (unsigned)err_ret_trace_arg_index, "nonnull");
+        }
     }
 
     return fn_table_entry->llvm_value;
@@ -6254,14 +6260,21 @@ static void do_code_gen(CodeGen *g) {
             } else if (is_c_abi) {
                 fn_walk_var.data.vars.var = var;
                 iter_function_params_c_abi(g, fn_table_entry->type_entry, &fn_walk_var, var->src_arg_index);
+            } else if (is_async) {
+                var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, "");
+                if (var->decl_node) {
+                    var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope),
+                        buf_ptr(&var->name), import->data.structure.root_struct->di_file,
+                        (unsigned)(var->decl_node->line + 1),
+                        get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0);
+                    gen_var_debug_decl(g, var);
+                }
             } else {
                 ZigType *gen_type;
                 FnGenParamInfo *gen_info = &fn_table_entry->type_entry->data.fn.gen_param_info[var->src_arg_index];
                 assert(gen_info->gen_index != SIZE_MAX);
 
-                if (is_async) {
-                    var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, "");
-                } else if (handle_is_ptr(var->var_type)) {
+                if (handle_is_ptr(var->var_type)) {
                     if (gen_info->is_byval) {
                         gen_type = var->var_type;
                     } else {
@@ -6307,16 +6320,7 @@ static void do_code_gen(CodeGen *g) {
             gen_store(g, LLVMConstInt(usize->llvm_type, stack_trace_ptr_count, false), len_field_ptr, get_pointer_to_type(g, usize, false));
         }
 
-        // create debug variable declarations for parameters
-        // rely on the first variables in the variable_list being parameters.
-        FnWalk fn_walk_init = {};
-        fn_walk_init.id = FnWalkIdInits;
-        fn_walk_init.data.inits.fn = fn_table_entry;
-        fn_walk_init.data.inits.llvm_fn = fn;
-        fn_walk_init.data.inits.gen_i = gen_i_init;
-        walk_function_params(g, fn_table_entry->type_entry, &fn_walk_init);
-
-        if (fn_table_entry->resume_blocks.length != 0) {
+        if (is_async) {
             if (!g->strip_debug_symbols) {
                 AstNode *source_node = fn_table_entry->proto_node;
                 ZigLLVMSetCurrentDebugLocation(g->builder, (int)source_node->line + 1,
@@ -6354,8 +6358,18 @@ static void do_code_gen(CodeGen *g) {
                 LLVMValueRef case_value = LLVMConstInt(usize_type_ref, resume_i + 2, false);
                 LLVMAddCase(switch_instr, case_value, fn_table_entry->resume_blocks.at(resume_i)->llvm_block);
             }
+        } else {
+            // create debug variable declarations for parameters
+            // rely on the first variables in the variable_list being parameters.
+            FnWalk fn_walk_init = {};
+            fn_walk_init.id = FnWalkIdInits;
+            fn_walk_init.data.inits.fn = fn_table_entry;
+            fn_walk_init.data.inits.llvm_fn = fn;
+            fn_walk_init.data.inits.gen_i = gen_i_init;
+            walk_function_params(g, fn_table_entry->type_entry, &fn_walk_init);
         }
 
+
         ir_render(g, fn_table_entry);
 
     }
test/stage1/behavior/coroutines.zig
@@ -29,6 +29,24 @@ fn simpleAsyncFnWithArg(delta: i32) void {
     suspend;
     global_y += delta;
 }
+
+test "suspend at end of function" {
+    const S = struct {
+        var x: i32 = 1;
+
+        fn doTheTest() void {
+            expect(x == 1);
+            const p = async suspendAtEnd();
+            expect(x == 2);
+        }
+
+        fn suspendAtEnd() void {
+            x += 1;
+            suspend;
+        }
+    };
+    S.doTheTest();
+}
 //test "coroutine suspend, resume" {
 //    seq('a');
 //    const p = try async<allocator> testAsyncSeq();
test/runtime_safety.zig
@@ -1,6 +1,20 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("invalid resume of async function",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var p = async suspendOnce();
+        \\    resume p; //ok
+        \\    resume p; //bad
+        \\}
+        \\fn suspendOnce() void {
+        \\    suspend;
+        \\}
+    );
+
     cases.addRuntimeSafety(".? operator on null pointer",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);