Commit 19ee495750

Andrew Kelley <andrew@ziglang.org>
2019-07-24 01:35:41
add error for function with ccc indirectly calling async function
1 parent 7e9760d
src/all_types.hpp
@@ -1342,7 +1342,6 @@ struct FnCall {
 };
 
 struct ZigFn {
-    CodeGen *codegen;
     LLVMValueRef llvm_value;
     const char *llvm_name;
     AstNode *proto_node;
@@ -1385,6 +1384,7 @@ struct ZigFn {
 
     AstNode *set_cold_node;
     const AstNode *inferred_async_node;
+    ZigFn *inferred_async_fn;
 
     ZigList<GlobalExport> export_list;
     ZigList<FnCall> call_list;
src/analyze.cpp
@@ -61,14 +61,14 @@ ErrorMsg *add_token_error(CodeGen *g, ZigType *owner, Token *token, Buf *msg) {
     return err;
 }
 
-ErrorMsg *add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
+ErrorMsg *add_node_error(CodeGen *g, const AstNode *node, Buf *msg) {
     Token fake_token;
     fake_token.start_line = node->line;
     fake_token.start_column = node->column;
     return add_token_error(g, node->owner, &fake_token, msg);
 }
 
-ErrorMsg *add_error_note(CodeGen *g, ErrorMsg *parent_msg, AstNode *node, Buf *msg) {
+ErrorMsg *add_error_note(CodeGen *g, ErrorMsg *parent_msg, const AstNode *node, Buf *msg) {
     Token fake_token;
     fake_token.start_line = node->line;
     fake_token.start_column = node->column;
@@ -2656,7 +2656,6 @@ ZigFn *create_fn_raw(CodeGen *g, FnInline inline_value) {
 
     fn_entry->prealloc_backward_branch_quota = default_backward_branch_quota;
 
-    fn_entry->codegen = g;
     fn_entry->analyzed_executable.backward_branch_count = &fn_entry->prealloc_bbc;
     fn_entry->analyzed_executable.backward_branch_quota = &fn_entry->prealloc_backward_branch_quota;
     fn_entry->analyzed_executable.fn_entry = fn_entry;
@@ -2784,6 +2783,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
                 }
             }
         } else {
+            fn_table_entry->inferred_async_node = inferred_async_none;
             g->external_prototypes.put_unique(tld_fn->base.name, &tld_fn->base);
         }
 
@@ -2805,14 +2805,11 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
                 g->fn_defs.append(fn_table_entry);
         }
 
-        switch (fn_table_entry->type_entry->data.fn.fn_type_id.cc) {
-            case CallingConventionAsync:
-                fn_table_entry->inferred_async_node = fn_table_entry->proto_node;
-                break;
-            case CallingConventionUnspecified:
-                break;
-            default:
-                fn_table_entry->inferred_async_node = inferred_async_none;
+        // if the calling convention implies that it cannot be async, we save that for later
+        // and leave the value to be nullptr to indicate that we have not emitted possible
+        // compile errors for improperly calling async functions.
+        if (fn_table_entry->type_entry->data.fn.fn_type_id.cc == CallingConventionAsync) {
+            fn_table_entry->inferred_async_node = fn_table_entry->proto_node;
         }
 
         if (scope_is_root_decls(tld_fn->base.parent_scope) &&
@@ -3801,6 +3798,25 @@ bool fn_is_async(ZigFn *fn) {
     return fn->inferred_async_node != inferred_async_none;
 }
 
+static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
+    assert(fn->inferred_async_node != nullptr);
+    assert(fn->inferred_async_node != inferred_async_checking);
+    assert(fn->inferred_async_node != inferred_async_none);
+    if (fn->inferred_async_fn != nullptr) {
+        ErrorMsg *new_msg = add_error_note(g, msg, fn->inferred_async_node,
+            buf_sprintf("async function call here"));
+        return add_async_error_notes(g, new_msg, fn->inferred_async_fn);
+    } else if (fn->inferred_async_node->type == NodeTypeFnProto) {
+        add_error_note(g, msg, fn->inferred_async_node,
+            buf_sprintf("async calling convention here"));
+    } else if (fn->inferred_async_node->type == NodeTypeSuspend) {
+        add_error_note(g, msg, fn->inferred_async_node,
+            buf_sprintf("suspends here"));
+    } else {
+        zig_unreachable();
+    }
+}
+
 // This function resolves functions being inferred async.
 static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
     if (fn->inferred_async_node == inferred_async_checking) {
@@ -3816,6 +3832,13 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
         return;
     }
     fn->inferred_async_node = inferred_async_checking;
+
+    bool must_not_be_async = false;
+    if (fn->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified) {
+        must_not_be_async = true;
+        fn->inferred_async_node = inferred_async_none;
+    }
+
     for (size_t i = 0; i < fn->call_list.length; i += 1) {
         FnCall *call = &fn->call_list.at(i);
         if (call->callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
@@ -3828,6 +3851,15 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
         }
         if (fn_is_async(call->callee)) {
             fn->inferred_async_node = call->source_node;
+            fn->inferred_async_fn = call->callee;
+            if (must_not_be_async) {
+                ErrorMsg *msg = add_node_error(g, fn->proto_node,
+                    buf_sprintf("function with calling convention '%s' cannot be async",
+                        calling_convention_name(fn->type_entry->data.fn.fn_type_id.cc)));
+                add_async_error_notes(g, msg, fn);
+                fn->anal_state = FnAnalStateInvalid;
+                return;
+            }
             resolve_async_fn_frame(g, fn);
             return;
         }
@@ -4451,7 +4483,7 @@ bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
         if (a_val->special != ConstValSpecialRuntime && b_val->special != ConstValSpecialRuntime) {
             assert(a_val->special == ConstValSpecialStatic);
             assert(b_val->special == ConstValSpecialStatic);
-            if (!const_values_equal(a->fn_entry->codegen, a_val, b_val)) {
+            if (!const_values_equal(a->codegen, a_val, b_val)) {
                 return false;
             }
         } else {
src/analyze.hpp
@@ -11,9 +11,9 @@
 #include "all_types.hpp"
 
 void semantic_analyze(CodeGen *g);
-ErrorMsg *add_node_error(CodeGen *g, AstNode *node, Buf *msg);
+ErrorMsg *add_node_error(CodeGen *g, const AstNode *node, Buf *msg);
 ErrorMsg *add_token_error(CodeGen *g, ZigType *owner, Token *token, Buf *msg);
-ErrorMsg *add_error_note(CodeGen *g, ErrorMsg *parent_msg, AstNode *node, Buf *msg);
+ErrorMsg *add_error_note(CodeGen *g, ErrorMsg *parent_msg, const AstNode *node, Buf *msg);
 void emit_error_notes_for_ref_stack(CodeGen *g, ErrorMsg *msg);
 ZigType *new_type_table_entry(ZigTypeId id);
 ZigType *get_coro_frame_type(CodeGen *g, ZigFn *fn);
test/compile_errors.zig
@@ -2,6 +2,24 @@ const tests = @import("tests.zig");
 const builtin = @import("builtin");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.add(
+        "function with ccc indirectly calling async function",
+        \\export fn entry() void {
+        \\    foo();
+        \\}
+        \\fn foo() void {
+        \\    bar();
+        \\}
+        \\fn bar() void {
+        \\    suspend;
+        \\}
+    ,
+        "tmp.zig:1:1: error: function with calling convention 'ccc' cannot be async",
+        "tmp.zig:2:8: note: async function call here",
+        "tmp.zig:5:8: note: async function call here",
+        "tmp.zig:8:5: note: suspends here",
+    );
+
     cases.add(
         "capture group on switch prong with incompatible payload types",
         \\const Union = union(enum) {