Commit dbdc4d62d0

Andrew Kelley <andrew@ziglang.org>
2019-07-30 01:32:49
improve support for anyframe and anyframe->T
* add implicit cast from `*@Frame(func)` to `anyframe->T` or `anyframe`. * add implicit cast from `anyframe->T` to `anyframe`. * `resume` works on `anyframe->T` and `anyframe` types.
1 parent ee64a22
Changed files (5)
src/all_types.hpp
@@ -1726,6 +1726,7 @@ struct CodeGen {
     LLVMValueRef err_name_table;
     LLVMValueRef safety_crash_err_fn;
     LLVMValueRef return_err_fn;
+    LLVMTypeRef async_fn_llvm_type;
 
     // reminder: hash tables must be initialized before use
     HashMap<Buf *, ZigType *, buf_hash, buf_eql_buf> import_table;
@@ -1793,7 +1794,6 @@ struct CodeGen {
         ZigType *entry_global_error_set;
         ZigType *entry_arg_tuple;
         ZigType *entry_enum_literal;
-        ZigType *entry_frame_header;
         ZigType *entry_any_frame;
     } builtin_types;
     ZigType *align_amt_type;
src/analyze.cpp
@@ -7348,19 +7348,13 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) {
     if (is_async) {
         fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(1);
 
-        ZigType *frame_type = g->builtin_types.entry_frame_header;
-        Error err;
-        if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
-            zig_unreachable();
-        }
-        ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
-        gen_param_types.append(get_llvm_type(g, ptr_type));
-        param_di_types.append(get_llvm_di_type(g, ptr_type));
+        ZigType *frame_type = get_any_frame_type(g, fn_type_id->return_type);
+        gen_param_types.append(get_llvm_type(g, frame_type));
+        param_di_types.append(get_llvm_di_type(g, frame_type));
 
         fn_type->data.fn.gen_param_info[0].src_index = 0;
         fn_type->data.fn.gen_param_info[0].gen_index = 0;
-        fn_type->data.fn.gen_param_info[0].type = ptr_type;
-
+        fn_type->data.fn.gen_param_info[0].type = frame_type;
     } else {
         fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(fn_type_id->param_count);
         for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
src/codegen.cpp
@@ -4902,14 +4902,28 @@ static LLVMValueRef ir_render_suspend_br(CodeGen *g, IrExecutable *executable,
     return nullptr;
 }
 
+static LLVMTypeRef async_fn_llvm_type(CodeGen *g) {
+    if (g->async_fn_llvm_type != nullptr)
+        return g->async_fn_llvm_type;
+
+    ZigType *anyframe_type = get_any_frame_type(g, nullptr);
+    LLVMTypeRef param_type = get_llvm_type(g, anyframe_type);
+    LLVMTypeRef return_type = LLVMVoidType();
+    LLVMTypeRef fn_type = LLVMFunctionType(return_type, &param_type, 1, false);
+    g->async_fn_llvm_type = LLVMPointerType(fn_type, 0);
+
+    return g->async_fn_llvm_type;
+}
+
 static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
         IrInstructionCoroResume *instruction)
 {
     LLVMValueRef frame = ir_llvm_value(g, instruction->frame);
     ZigType *frame_type = instruction->frame->value.type;
-    assert(frame_type->id == ZigTypeIdCoroFrame);
-    ZigFn *fn = frame_type->data.frame.fn;
-    LLVMValueRef fn_val = fn_llvm_value(g, fn);
+    assert(frame_type->id == ZigTypeIdAnyFrame);
+    LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
+    LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
+    LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, async_fn_llvm_type(g), "");
     ZigLLVMBuildCall(g->builder, fn_val, &frame, 1, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
     return nullptr;
 }
@@ -6746,11 +6760,6 @@ static void define_builtin_types(CodeGen *g) {
 
         g->primitive_type_table.put(&entry->name, entry);
     }
-    {
-        const char *field_names[] = {"resume_index"};
-        ZigType *field_types[] = {g->builtin_types.entry_usize};
-        g->builtin_types.entry_frame_header = get_struct_type(g, "(frame header)", field_names, field_types, 1);
-    }
 }
 
 static BuiltinFnEntry *create_builtin_fn(CodeGen *g, BuiltinFnId id, const char *name, size_t count) {
src/ir.cpp
@@ -7764,7 +7764,7 @@ static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node)
 static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) {
     assert(node->type == NodeTypeResume);
 
-    IrInstruction *target_inst = ir_gen_node(irb, node->data.resume_expr.expr, scope);
+    IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.resume_expr.expr, scope, LValPtr, nullptr);
     if (target_inst == irb->codegen->invalid_instruction)
         return irb->codegen->invalid_instruction;
 
@@ -10882,6 +10882,33 @@ static IrInstruction *ir_analyze_err_set_cast(IrAnalyze *ira, IrInstruction *sou
     return result;
 }
 
+static IrInstruction *ir_analyze_frame_ptr_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
+        IrInstruction *value, ZigType *wanted_type)
+{
+    if (instr_is_comptime(value)) {
+        zig_panic("TODO comptime frame pointer");
+    }
+
+    IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
+            wanted_type, value, CastOpBitCast);
+    result->value.type = wanted_type;
+    return result;
+}
+
+static IrInstruction *ir_analyze_anyframe_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
+        IrInstruction *value, ZigType *wanted_type)
+{
+    if (instr_is_comptime(value)) {
+        zig_panic("TODO comptime anyframe->T to anyframe");
+    }
+
+    IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
+            wanted_type, value, CastOpBitCast);
+    result->value.type = wanted_type;
+    return result;
+}
+
+
 static IrInstruction *ir_analyze_err_wrap_code(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *value,
         ZigType *wanted_type, ResultLoc *result_loc)
 {
@@ -11978,6 +12005,29 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
         }
     }
 
+    // *@Frame(func) to anyframe->T or anyframe
+    if (actual_type->id == ZigTypeIdPointer && actual_type->data.pointer.ptr_len == PtrLenSingle &&
+        actual_type->data.pointer.child_type->id == ZigTypeIdCoroFrame && wanted_type->id == ZigTypeIdAnyFrame)
+    {
+        bool ok = true;
+        if (wanted_type->data.any_frame.result_type != nullptr) {
+            ZigFn *fn = actual_type->data.pointer.child_type->data.frame.fn;
+            ZigType *fn_return_type = fn->type_entry->data.fn.fn_type_id.return_type;
+            if (wanted_type->data.any_frame.result_type != fn_return_type) {
+                ok = false;
+            }
+        }
+        if (ok) {
+            return ir_analyze_frame_ptr_to_anyframe(ira, source_instr, value, wanted_type);
+        }
+    }
+
+    // anyframe->T to anyframe
+    if (actual_type->id == ZigTypeIdAnyFrame && actual_type->data.any_frame.result_type != nullptr &&
+        wanted_type->id == ZigTypeIdAnyFrame && wanted_type->data.any_frame.result_type == nullptr)
+    {
+        return ir_analyze_anyframe_to_anyframe(ira, source_instr, value, wanted_type);
+    }
 
     // cast from null literal to maybe type
     if (wanted_type->id == ZigTypeIdOptional &&
@@ -24323,17 +24373,27 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru
 }
 
 static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstructionCoroResume *instruction) {
-    IrInstruction *frame = instruction->frame->child;
-    if (type_is_invalid(frame->value.type))
+    IrInstruction *frame_ptr = instruction->frame->child;
+    if (type_is_invalid(frame_ptr->value.type))
         return ira->codegen->invalid_instruction;
 
-    if (frame->value.type->id != ZigTypeIdCoroFrame) {
-        ir_add_error(ira, instruction->frame,
-            buf_sprintf("expected frame, found '%s'", buf_ptr(&frame->value.type->name)));
-        return ira->codegen->invalid_instruction;
+    IrInstruction *frame;
+    if (frame_ptr->value.type->id == ZigTypeIdPointer &&
+        frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
+        frame_ptr->value.type->data.pointer.is_const &&
+        frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdAnyFrame)
+    {
+        frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr);
+    } else {
+        frame = frame_ptr;
     }
 
-    return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, frame);
+    ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
+    IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
+    if (type_is_invalid(casted_frame->value.type))
+        return ira->codegen->invalid_instruction;
+
+    return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
 }
 
 static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
test/stage1/behavior/coroutines.zig
@@ -5,15 +5,20 @@ const expect = std.testing.expect;
 var global_x: i32 = 1;
 
 test "simple coroutine suspend and resume" {
-    const p = async simpleAsyncFn();
+    const frame = async simpleAsyncFn();
     expect(global_x == 2);
-    resume p;
+    resume frame;
     expect(global_x == 3);
+    const af: anyframe->void = &frame;
+    resume frame;
+    expect(global_x == 4);
 }
 fn simpleAsyncFn() void {
     global_x += 1;
     suspend;
     global_x += 1;
+    suspend;
+    global_x += 1;
 }
 
 var global_y: i32 = 1;