Commit 50b70bd77f

Alexandros Naskos <alex_naskos@hotmail.com>
2020-06-24 13:07:39
@asyncCall now requires an argument tuple
1 parent 3aab601
lib/std/special/test_runner.zig
@@ -35,7 +35,7 @@ pub fn main() anyerror!void {
                     async_frame_buffer = try std.heap.page_allocator.alignedAlloc(u8, std.Target.stack_align, size);
                 }
                 const casted_fn = @ptrCast(fn () callconv(.Async) anyerror!void, test_fn.func);
-                break :blk await @asyncCall(async_frame_buffer, {}, casted_fn);
+                break :blk await @asyncCall(async_frame_buffer, {}, casted_fn, .{});
             },
             .blocking => {
                 skip_count += 1;
lib/std/dwarf.zig
@@ -359,7 +359,7 @@ fn parseFormValue(allocator: *mem.Allocator, in_stream: var, form_id: u64, endia
             const F = @TypeOf(async parseFormValue(allocator, in_stream, child_form_id, endian, is_64));
             var frame = try allocator.create(F);
             defer allocator.destroy(frame);
-            return await @asyncCall(frame, {}, parseFormValue, allocator, in_stream, child_form_id, endian, is_64);
+            return await @asyncCall(frame, {}, parseFormValue, .{ allocator, in_stream, child_form_id, endian, is_64 });
         },
         else => error.InvalidDebugInfo,
     };
lib/std/start.zig
@@ -214,7 +214,7 @@ inline fn initEventLoopAndCallMain() u8 {
 
             var result: u8 = undefined;
             var frame: @Frame(callMainAsync) = undefined;
-            _ = @asyncCall(&frame, &result, callMainAsync, loop);
+            _ = @asyncCall(&frame, &result, callMainAsync, .{loop});
             loop.run();
             return result;
         }
src/all_types.hpp
@@ -2641,6 +2641,7 @@ enum IrInstSrcId {
     IrInstSrcIdCall,
     IrInstSrcIdCallArgs,
     IrInstSrcIdCallExtra,
+    IrInstSrcIdAsyncCallExtra,
     IrInstSrcIdConst,
     IrInstSrcIdReturn,
     IrInstSrcIdContainerInitList,
@@ -3255,6 +3256,20 @@ struct IrInstSrcCallExtra {
     ResultLoc *result_loc;
 };
 
+// This is a pass1 instruction, used by @asyncCall, when the args node
+// is not a literal.
+// `args` is expected to be either a struct or a tuple.
+struct IrInstSrcAsyncCallExtra {
+    IrInstSrc base;
+
+    CallModifier modifier;
+    IrInstSrc *fn_ref;
+    IrInstSrc *ret_ptr;
+    IrInstSrc *new_stack;
+    IrInstSrc *args;
+    ResultLoc *result_loc;
+};
+
 struct IrInstGenCall {
     IrInstGen base;
 
src/ir.cpp
@@ -310,6 +310,8 @@ static void destroy_instruction_src(IrInstSrc *inst) {
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCall *>(inst));
         case IrInstSrcIdCallExtra:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcCallExtra *>(inst));
+        case IrInstSrcIdAsyncCallExtra:
+            return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcAsyncCallExtra *>(inst));
         case IrInstSrcIdUnOp:
             return heap::c_allocator.destroy(reinterpret_cast<IrInstSrcUnOp *>(inst));
         case IrInstSrcIdCondBr:
@@ -1173,6 +1175,10 @@ static constexpr IrInstSrcId ir_inst_id(IrInstSrcCallExtra *) {
     return IrInstSrcIdCallExtra;
 }
 
+static constexpr IrInstSrcId ir_inst_id(IrInstSrcAsyncCallExtra *) {
+    return IrInstSrcIdAsyncCallExtra;
+}
+
 static constexpr IrInstSrcId ir_inst_id(IrInstSrcConst *) {
     return IrInstSrcIdConst;
 }
@@ -2442,6 +2448,25 @@ static IrInstSrc *ir_build_call_extra(IrBuilderSrc *irb, Scope *scope, AstNode *
     return &call_instruction->base;
 }
 
+static IrInstSrc *ir_build_async_call_extra(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
+        CallModifier modifier, IrInstSrc *fn_ref, IrInstSrc *ret_ptr, IrInstSrc *new_stack, IrInstSrc *args, ResultLoc *result_loc)
+{
+    IrInstSrcAsyncCallExtra *call_instruction = ir_build_instruction<IrInstSrcAsyncCallExtra>(irb, scope, source_node);
+    call_instruction->modifier = modifier;
+    call_instruction->fn_ref = fn_ref;
+    call_instruction->ret_ptr = ret_ptr;
+    call_instruction->new_stack = new_stack;
+    call_instruction->args = args;
+    call_instruction->result_loc = result_loc;
+
+    ir_ref_instruction(fn_ref, irb->current_basic_block);
+    if (ret_ptr != nullptr) ir_ref_instruction(ret_ptr, irb->current_basic_block);
+    ir_ref_instruction(new_stack, irb->current_basic_block);
+    ir_ref_instruction(args, irb->current_basic_block);
+
+    return &call_instruction->base;
+}
+
 static IrInstSrc *ir_build_call_args(IrBuilderSrc *irb, Scope *scope, AstNode *source_node,
         IrInstSrc *options, IrInstSrc *fn_ref, IrInstSrc **args_ptr, size_t args_len,
         ResultLoc *result_loc)
@@ -6183,11 +6208,10 @@ static IrInstSrc *ir_gen_this(IrBuilderSrc *irb, Scope *orig_scope, AstNode *nod
 static IrInstSrc *ir_gen_async_call(IrBuilderSrc *irb, Scope *scope, AstNode *await_node, AstNode *call_node,
         LVal lval, ResultLoc *result_loc)
 {
-    size_t arg_offset = 3;
-    if (call_node->data.fn_call_expr.params.length < arg_offset) {
+    if (call_node->data.fn_call_expr.params.length != 4) {
         add_node_error(irb->codegen, call_node,
-            buf_sprintf("expected at least %" ZIG_PRI_usize " arguments, found %" ZIG_PRI_usize,
-                arg_offset, call_node->data.fn_call_expr.params.length));
+            buf_sprintf("expected 4 arguments, found %" ZIG_PRI_usize,
+                call_node->data.fn_call_expr.params.length));
         return irb->codegen->invalid_inst_src;
     }
 
@@ -6206,20 +6230,37 @@ static IrInstSrc *ir_gen_async_call(IrBuilderSrc *irb, Scope *scope, AstNode *aw
     if (fn_ref == irb->codegen->invalid_inst_src)
         return fn_ref;
 
-    size_t arg_count = call_node->data.fn_call_expr.params.length - arg_offset;
-    IrInstSrc **args = heap::c_allocator.allocate<IrInstSrc*>(arg_count);
-    for (size_t i = 0; i < arg_count; i += 1) {
-        AstNode *arg_node = call_node->data.fn_call_expr.params.at(i + arg_offset);
-        IrInstSrc *arg = ir_gen_node(irb, arg_node, scope);
-        if (arg == irb->codegen->invalid_inst_src)
-            return arg;
-        args[i] = arg;
-    }
-
     CallModifier modifier = (await_node == nullptr) ? CallModifierAsync : CallModifierNone;
     bool is_async_call_builtin = true;
-    IrInstSrc *call = ir_build_call_src(irb, scope, call_node, nullptr, fn_ref, arg_count, args,
-            ret_ptr, modifier, is_async_call_builtin, bytes, result_loc);
+    AstNode *args_node = call_node->data.fn_call_expr.params.at(3);
+    if (args_node->type == NodeTypeContainerInitExpr) {
+        if (args_node->data.container_init_expr.kind == ContainerInitKindArray ||
+            args_node->data.container_init_expr.entries.length == 0)
+        {
+            size_t arg_count = args_node->data.container_init_expr.entries.length;
+            IrInstSrc **args = heap::c_allocator.allocate<IrInstSrc*>(arg_count);
+            for (size_t i = 0; i < arg_count; i += 1) {
+                AstNode *arg_node = args_node->data.container_init_expr.entries.at(i);
+                IrInstSrc *arg = ir_gen_node(irb, arg_node, scope);
+                if (arg == irb->codegen->invalid_inst_src)
+                    return arg;
+                args[i] = arg;
+            }
+
+            IrInstSrc *call = ir_build_call_src(irb, scope, call_node, nullptr, fn_ref, arg_count, args,
+                ret_ptr, modifier, is_async_call_builtin, bytes, result_loc);
+            return ir_lval_wrap(irb, scope, call, lval, result_loc);
+        } else {
+            exec_add_error_node(irb->codegen, irb->exec, args_node,
+                    buf_sprintf("TODO: @asyncCall with anon struct literal"));
+            return irb->codegen->invalid_inst_src;
+        }
+    }
+    IrInstSrc *args = ir_gen_node(irb, args_node, scope);
+    if (args == irb->codegen->invalid_inst_src)
+        return args;
+
+    IrInstSrc *call = ir_build_async_call_extra(irb, scope, call_node, modifier, fn_ref, bytes, ret_ptr, args, result_loc);
     return ir_lval_wrap(irb, scope, call, lval, result_loc);
 }
 
@@ -20236,7 +20277,7 @@ static IrInstGen *ir_analyze_fn_call(IrAnalyze *ira, IrInst* source_instr,
         // Fork a scope of the function with known values for the parameters.
         Scope *parent_scope = fn_entry->fndef_scope->base.parent;
         ZigFn *impl_fn = create_fn(ira->codegen, fn_proto_node);
-        impl_fn->param_source_nodes = heap::c_allocator.allocate<AstNode *>(new_fn_arg_count);
+        
         buf_init_from_buf(&impl_fn->symbol_name, &fn_entry->symbol_name);
         impl_fn->fndef_scope = create_fndef_scope(ira->codegen, impl_fn->body_node, parent_scope, impl_fn);
         impl_fn->child_scope = &impl_fn->fndef_scope->base;
@@ -20719,40 +20760,101 @@ static IrInstGen *ir_analyze_call_extra(IrAnalyze *ira, IrInst* source_instr,
         modifier, stack, stack_src, false, args_ptr, args_len, nullptr, result_loc);
 }
 
-static IrInstGen *ir_analyze_instruction_call_extra(IrAnalyze *ira, IrInstSrcCallExtra *instruction) {
-    IrInstGen *args = instruction->args->child;
+static IrInstGen *ir_analyze_async_call_extra(IrAnalyze *ira, IrInst* source_instr, CallModifier modifier,
+        IrInstSrc *pass1_fn_ref, IrInstSrc *ret_ptr, IrInstSrc *new_stack, IrInstGen **args_ptr, size_t args_len, ResultLoc *result_loc)
+{
+    IrInstGen *fn_ref = pass1_fn_ref->child;
+    if (type_is_invalid(fn_ref->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    if (ir_should_inline(ira->old_irb.exec, source_instr->scope)) {
+        ir_add_error(ira, source_instr, buf_sprintf("TODO: comptime @asyncCall"));
+            return ira->codegen->invalid_inst_gen;
+    }
+
+    ZigFn *fn = nullptr;
+    if (instr_is_comptime(fn_ref)) {
+        if (fn_ref->value->type->id == ZigTypeIdBoundFn) {
+            assert(fn_ref->value->special == ConstValSpecialStatic);
+            fn = fn_ref->value->data.x_bound_fn.fn;
+        } else {
+            fn = ir_resolve_fn(ira, fn_ref);
+        }
+    }
+
+    IrInstGen *ret_ptr_uncasted = nullptr;
+    if (ret_ptr != nullptr) {
+        ret_ptr_uncasted = ret_ptr->child;
+        if (type_is_invalid(ret_ptr_uncasted->value->type))
+            return ira->codegen->invalid_inst_gen;
+    }
+
+    ZigType *fn_type = (fn != nullptr) ? fn->type_entry : fn_ref->value->type;
+    IrInstGen *casted_new_stack = analyze_casted_new_stack(ira, source_instr, new_stack->child,
+            &new_stack->base, true, fn);
+    if (casted_new_stack != nullptr && type_is_invalid(casted_new_stack->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    IrInstGen *result =  ir_analyze_async_call(ira, source_instr, fn, fn_type, fn_ref, args_ptr, args_len,
+        casted_new_stack, true, ret_ptr_uncasted, result_loc);
+    return ir_finish_anal(ira, result);
+}
+
+static bool ir_extract_tuple_call_args(IrAnalyze *ira, IrInst *source_instr, IrInstGen *args, IrInstGen ***args_ptr, size_t *args_len) {
     ZigType *args_type = args->value->type;
     if (type_is_invalid(args_type))
-        return ira->codegen->invalid_inst_gen;
+        return false;
 
     if (args_type->id != ZigTypeIdStruct) {
         ir_add_error(ira, &args->base,
             buf_sprintf("expected tuple or struct, found '%s'", buf_ptr(&args_type->name)));
-        return ira->codegen->invalid_inst_gen;
+        return false;
     }
 
-    IrInstGen **args_ptr = nullptr;
-    size_t args_len = 0;
-
     if (is_tuple(args_type)) {
-        args_len = args_type->data.structure.src_field_count;
-        args_ptr = heap::c_allocator.allocate<IrInstGen *>(args_len);
-        for (size_t i = 0; i < args_len; i += 1) {
+        *args_len = args_type->data.structure.src_field_count;
+        *args_ptr = heap::c_allocator.allocate<IrInstGen *>(*args_len);
+        for (size_t i = 0; i < *args_len; i += 1) {
             TypeStructField *arg_field = args_type->data.structure.fields[i];
-            args_ptr[i] = ir_analyze_struct_value_field_value(ira, &instruction->base.base, args, arg_field);
-            if (type_is_invalid(args_ptr[i]->value->type))
-                return ira->codegen->invalid_inst_gen;
+            (*args_ptr)[i] = ir_analyze_struct_value_field_value(ira, source_instr, args, arg_field);
+            if (type_is_invalid((*args_ptr)[i]->value->type))
+                return false;
         }
     } else {
         ir_add_error(ira, &args->base, buf_sprintf("TODO: struct args"));
+        return false;
+    }
+    return true;
+}
+
+static IrInstGen *ir_analyze_instruction_call_extra(IrAnalyze *ira, IrInstSrcCallExtra *instruction) {
+    IrInstGen *args = instruction->args->child;
+    IrInstGen **args_ptr = nullptr;
+    size_t args_len = 0;
+    if (!ir_extract_tuple_call_args(ira, &instruction->base.base, args, &args_ptr, &args_len)) {
         return ira->codegen->invalid_inst_gen;
     }
+
     IrInstGen *result = ir_analyze_call_extra(ira, &instruction->base.base, instruction->options,
             instruction->fn_ref, args_ptr, args_len, instruction->result_loc);
     heap::c_allocator.deallocate(args_ptr, args_len);
     return result;
 }
 
+static IrInstGen *ir_analyze_instruction_async_call_extra(IrAnalyze *ira, IrInstSrcAsyncCallExtra *instruction) {
+    IrInstGen *args = instruction->args->child;
+    IrInstGen **args_ptr = nullptr;
+    size_t args_len = 0;
+    if (!ir_extract_tuple_call_args(ira, &instruction->base.base, args, &args_ptr, &args_len)) {
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    IrInstGen *result = ir_analyze_async_call_extra(ira, &instruction->base.base, instruction->modifier,
+            instruction->fn_ref, instruction->ret_ptr, instruction->new_stack, args_ptr, args_len, instruction->result_loc);
+    heap::c_allocator.deallocate(args_ptr, args_len);
+    return result;
+}
+
 static IrInstGen *ir_analyze_instruction_call_args(IrAnalyze *ira, IrInstSrcCallArgs *instruction) {
     IrInstGen **args_ptr = heap::c_allocator.allocate<IrInstGen *>(instruction->args_len);
     for (size_t i = 0; i < instruction->args_len; i += 1) {
@@ -31101,6 +31203,8 @@ static IrInstGen *ir_analyze_instruction_base(IrAnalyze *ira, IrInstSrc *instruc
             return ir_analyze_instruction_call_args(ira, (IrInstSrcCallArgs *)instruction);
         case IrInstSrcIdCallExtra:
             return ir_analyze_instruction_call_extra(ira, (IrInstSrcCallExtra *)instruction);
+        case IrInstSrcIdAsyncCallExtra:
+            return ir_analyze_instruction_async_call_extra(ira, (IrInstSrcAsyncCallExtra *)instruction);
         case IrInstSrcIdBr:
             return ir_analyze_instruction_br(ira, (IrInstSrcBr *)instruction);
         case IrInstSrcIdCondBr:
@@ -31610,6 +31714,7 @@ bool ir_inst_src_has_side_effects(IrInstSrc *instruction) {
         case IrInstSrcIdDeclVar:
         case IrInstSrcIdStorePtr:
         case IrInstSrcIdCallExtra:
+        case IrInstSrcIdAsyncCallExtra:
         case IrInstSrcIdCall:
         case IrInstSrcIdCallArgs:
         case IrInstSrcIdReturn:
src/ir_print.cpp
@@ -5,6 +5,7 @@
  * See http://opensource.org/licenses/MIT
  */
 
+#include "all_types.hpp"
 #include "analyze.hpp"
 #include "ir.hpp"
 #include "ir_print.hpp"
@@ -55,6 +56,36 @@ struct IrPrintGen {
 static void ir_print_other_inst_src(IrPrintSrc *irp, IrInstSrc *inst);
 static void ir_print_other_inst_gen(IrPrintGen *irp, IrInstGen *inst);
 
+static void ir_print_call_modifier(FILE *f, CallModifier modifier) {
+    switch (modifier) {
+        case CallModifierNone:
+            break;
+        case CallModifierNoSuspend:
+            fprintf(f, "nosuspend ");
+            break;
+        case CallModifierAsync:
+            fprintf(f, "async ");
+            break;
+        case CallModifierNeverTail:
+            fprintf(f, "notail ");
+            break;
+        case CallModifierNeverInline:
+            fprintf(f, "noinline ");
+            break;
+        case CallModifierAlwaysTail:
+            fprintf(f, "tail ");
+            break;
+        case CallModifierAlwaysInline:
+            fprintf(f, "inline ");
+            break;
+        case CallModifierCompileTime:
+            fprintf(f, "comptime ");
+            break;
+        case CallModifierBuiltin:
+            zig_unreachable();
+    }
+}
+
 const char* ir_inst_src_type_str(IrInstSrcId id) {
     switch (id) {
         case IrInstSrcIdInvalid:
@@ -97,6 +128,8 @@ const char* ir_inst_src_type_str(IrInstSrcId id) {
             return "SrcVarPtr";
         case IrInstSrcIdCallExtra:
             return "SrcCallExtra";
+        case IrInstSrcIdAsyncCallExtra:
+            return "SrcAsyncCallExtra";
         case IrInstSrcIdCall:
             return "SrcCall";
         case IrInstSrcIdCallArgs:
@@ -851,6 +884,23 @@ static void ir_print_call_extra(IrPrintSrc *irp, IrInstSrcCallExtra *instruction
     ir_print_result_loc(irp, instruction->result_loc);
 }
 
+static void ir_print_async_call_extra(IrPrintSrc *irp, IrInstSrcAsyncCallExtra *instruction) {
+    fprintf(irp->f, "modifier=");
+    ir_print_call_modifier(irp->f, instruction->modifier);
+    fprintf(irp->f, ", fn=");
+    ir_print_other_inst_src(irp, instruction->fn_ref);
+    if (instruction->ret_ptr != nullptr) {
+        fprintf(irp->f, ", ret_ptr=");
+        ir_print_other_inst_src(irp, instruction->ret_ptr);
+    }
+    fprintf(irp->f, ", new_stack=");
+    ir_print_other_inst_src(irp, instruction->new_stack);
+    fprintf(irp->f, ", args=");
+    ir_print_other_inst_src(irp, instruction->args);
+    fprintf(irp->f, ", result=");
+    ir_print_result_loc(irp, instruction->result_loc);
+}
+
 static void ir_print_call_args(IrPrintSrc *irp, IrInstSrcCallArgs *instruction) {
     fprintf(irp->f, "opts=");
     ir_print_other_inst_src(irp, instruction->options);
@@ -868,33 +918,7 @@ static void ir_print_call_args(IrPrintSrc *irp, IrInstSrcCallArgs *instruction)
 }
 
 static void ir_print_call_src(IrPrintSrc *irp, IrInstSrcCall *call_instruction) {
-    switch (call_instruction->modifier) {
-        case CallModifierNone:
-            break;
-        case CallModifierNoSuspend:
-            fprintf(irp->f, "nosuspend ");
-            break;
-        case CallModifierAsync:
-            fprintf(irp->f, "async ");
-            break;
-        case CallModifierNeverTail:
-            fprintf(irp->f, "notail ");
-            break;
-        case CallModifierNeverInline:
-            fprintf(irp->f, "noinline ");
-            break;
-        case CallModifierAlwaysTail:
-            fprintf(irp->f, "tail ");
-            break;
-        case CallModifierAlwaysInline:
-            fprintf(irp->f, "inline ");
-            break;
-        case CallModifierCompileTime:
-            fprintf(irp->f, "comptime ");
-            break;
-        case CallModifierBuiltin:
-            zig_unreachable();
-    }
+    ir_print_call_modifier(irp->f, call_instruction->modifier);
     if (call_instruction->fn_entry) {
         fprintf(irp->f, "%s", buf_ptr(&call_instruction->fn_entry->symbol_name));
     } else {
@@ -913,33 +937,7 @@ static void ir_print_call_src(IrPrintSrc *irp, IrInstSrcCall *call_instruction)
 }
 
 static void ir_print_call_gen(IrPrintGen *irp, IrInstGenCall *call_instruction) {
-    switch (call_instruction->modifier) {
-        case CallModifierNone:
-            break;
-        case CallModifierNoSuspend:
-            fprintf(irp->f, "nosuspend ");
-            break;
-        case CallModifierAsync:
-            fprintf(irp->f, "async ");
-            break;
-        case CallModifierNeverTail:
-            fprintf(irp->f, "notail ");
-            break;
-        case CallModifierNeverInline:
-            fprintf(irp->f, "noinline ");
-            break;
-        case CallModifierAlwaysTail:
-            fprintf(irp->f, "tail ");
-            break;
-        case CallModifierAlwaysInline:
-            fprintf(irp->f, "inline ");
-            break;
-        case CallModifierCompileTime:
-            fprintf(irp->f, "comptime ");
-            break;
-        case CallModifierBuiltin:
-            zig_unreachable();
-    }
+    ir_print_call_modifier(irp->f, call_instruction->modifier);
     if (call_instruction->fn_entry) {
         fprintf(irp->f, "%s", buf_ptr(&call_instruction->fn_entry->symbol_name));
     } else {
@@ -2619,6 +2617,9 @@ static void ir_print_inst_src(IrPrintSrc *irp, IrInstSrc *instruction, bool trai
         case IrInstSrcIdCallExtra:
             ir_print_call_extra(irp, (IrInstSrcCallExtra *)instruction);
             break;
+        case IrInstSrcIdAsyncCallExtra:
+            ir_print_async_call_extra(irp, (IrInstSrcAsyncCallExtra *)instruction);
+            break;
         case IrInstSrcIdCall:
             ir_print_call_src(irp, (IrInstSrcCall *)instruction);
             break;