Commit dbdc4d62d0
Changed files (5)
test
stage1
behavior
src/all_types.hpp
@@ -1726,6 +1726,7 @@ struct CodeGen {
LLVMValueRef err_name_table;
LLVMValueRef safety_crash_err_fn;
LLVMValueRef return_err_fn;
+ LLVMTypeRef async_fn_llvm_type;
// reminder: hash tables must be initialized before use
HashMap<Buf *, ZigType *, buf_hash, buf_eql_buf> import_table;
@@ -1793,7 +1794,6 @@ struct CodeGen {
ZigType *entry_global_error_set;
ZigType *entry_arg_tuple;
ZigType *entry_enum_literal;
- ZigType *entry_frame_header;
ZigType *entry_any_frame;
} builtin_types;
ZigType *align_amt_type;
src/analyze.cpp
@@ -7348,19 +7348,13 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) {
if (is_async) {
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(1);
- ZigType *frame_type = g->builtin_types.entry_frame_header;
- Error err;
- if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
- zig_unreachable();
- }
- ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
- gen_param_types.append(get_llvm_type(g, ptr_type));
- param_di_types.append(get_llvm_di_type(g, ptr_type));
+ ZigType *frame_type = get_any_frame_type(g, fn_type_id->return_type);
+ gen_param_types.append(get_llvm_type(g, frame_type));
+ param_di_types.append(get_llvm_di_type(g, frame_type));
fn_type->data.fn.gen_param_info[0].src_index = 0;
fn_type->data.fn.gen_param_info[0].gen_index = 0;
- fn_type->data.fn.gen_param_info[0].type = ptr_type;
-
+ fn_type->data.fn.gen_param_info[0].type = frame_type;
} else {
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(fn_type_id->param_count);
for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
src/codegen.cpp
@@ -4902,14 +4902,28 @@ static LLVMValueRef ir_render_suspend_br(CodeGen *g, IrExecutable *executable,
return nullptr;
}
+static LLVMTypeRef async_fn_llvm_type(CodeGen *g) {
+ if (g->async_fn_llvm_type != nullptr)
+ return g->async_fn_llvm_type;
+
+ ZigType *anyframe_type = get_any_frame_type(g, nullptr);
+ LLVMTypeRef param_type = get_llvm_type(g, anyframe_type);
+ LLVMTypeRef return_type = LLVMVoidType();
+ LLVMTypeRef fn_type = LLVMFunctionType(return_type, ¶m_type, 1, false);
+ g->async_fn_llvm_type = LLVMPointerType(fn_type, 0);
+
+ return g->async_fn_llvm_type;
+}
+
static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
IrInstructionCoroResume *instruction)
{
LLVMValueRef frame = ir_llvm_value(g, instruction->frame);
ZigType *frame_type = instruction->frame->value.type;
- assert(frame_type->id == ZigTypeIdCoroFrame);
- ZigFn *fn = frame_type->data.frame.fn;
- LLVMValueRef fn_val = fn_llvm_value(g, fn);
+ assert(frame_type->id == ZigTypeIdAnyFrame);
+ LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
+ LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
+ LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, async_fn_llvm_type(g), "");
ZigLLVMBuildCall(g->builder, fn_val, &frame, 1, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
return nullptr;
}
@@ -6746,11 +6760,6 @@ static void define_builtin_types(CodeGen *g) {
g->primitive_type_table.put(&entry->name, entry);
}
- {
- const char *field_names[] = {"resume_index"};
- ZigType *field_types[] = {g->builtin_types.entry_usize};
- g->builtin_types.entry_frame_header = get_struct_type(g, "(frame header)", field_names, field_types, 1);
- }
}
static BuiltinFnEntry *create_builtin_fn(CodeGen *g, BuiltinFnId id, const char *name, size_t count) {
src/ir.cpp
@@ -7764,7 +7764,7 @@ static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node)
static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) {
assert(node->type == NodeTypeResume);
- IrInstruction *target_inst = ir_gen_node(irb, node->data.resume_expr.expr, scope);
+ IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.resume_expr.expr, scope, LValPtr, nullptr);
if (target_inst == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
@@ -10882,6 +10882,33 @@ static IrInstruction *ir_analyze_err_set_cast(IrAnalyze *ira, IrInstruction *sou
return result;
}
+static IrInstruction *ir_analyze_frame_ptr_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
+ IrInstruction *value, ZigType *wanted_type)
+{
+ if (instr_is_comptime(value)) {
+ zig_panic("TODO comptime frame pointer");
+ }
+
+ IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
+ wanted_type, value, CastOpBitCast);
+ result->value.type = wanted_type;
+ return result;
+}
+
+static IrInstruction *ir_analyze_anyframe_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
+ IrInstruction *value, ZigType *wanted_type)
+{
+ if (instr_is_comptime(value)) {
+ zig_panic("TODO comptime anyframe->T to anyframe");
+ }
+
+ IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
+ wanted_type, value, CastOpBitCast);
+ result->value.type = wanted_type;
+ return result;
+}
+
+
static IrInstruction *ir_analyze_err_wrap_code(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *value,
ZigType *wanted_type, ResultLoc *result_loc)
{
@@ -11978,6 +12005,29 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
}
}
+ // *@Frame(func) to anyframe->T or anyframe
+ if (actual_type->id == ZigTypeIdPointer && actual_type->data.pointer.ptr_len == PtrLenSingle &&
+ actual_type->data.pointer.child_type->id == ZigTypeIdCoroFrame && wanted_type->id == ZigTypeIdAnyFrame)
+ {
+ bool ok = true;
+ if (wanted_type->data.any_frame.result_type != nullptr) {
+ ZigFn *fn = actual_type->data.pointer.child_type->data.frame.fn;
+ ZigType *fn_return_type = fn->type_entry->data.fn.fn_type_id.return_type;
+ if (wanted_type->data.any_frame.result_type != fn_return_type) {
+ ok = false;
+ }
+ }
+ if (ok) {
+ return ir_analyze_frame_ptr_to_anyframe(ira, source_instr, value, wanted_type);
+ }
+ }
+
+ // anyframe->T to anyframe
+ if (actual_type->id == ZigTypeIdAnyFrame && actual_type->data.any_frame.result_type != nullptr &&
+ wanted_type->id == ZigTypeIdAnyFrame && wanted_type->data.any_frame.result_type == nullptr)
+ {
+ return ir_analyze_anyframe_to_anyframe(ira, source_instr, value, wanted_type);
+ }
// cast from null literal to maybe type
if (wanted_type->id == ZigTypeIdOptional &&
@@ -24323,17 +24373,27 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru
}
static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstructionCoroResume *instruction) {
- IrInstruction *frame = instruction->frame->child;
- if (type_is_invalid(frame->value.type))
+ IrInstruction *frame_ptr = instruction->frame->child;
+ if (type_is_invalid(frame_ptr->value.type))
return ira->codegen->invalid_instruction;
- if (frame->value.type->id != ZigTypeIdCoroFrame) {
- ir_add_error(ira, instruction->frame,
- buf_sprintf("expected frame, found '%s'", buf_ptr(&frame->value.type->name)));
- return ira->codegen->invalid_instruction;
+ IrInstruction *frame;
+ if (frame_ptr->value.type->id == ZigTypeIdPointer &&
+ frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
+ frame_ptr->value.type->data.pointer.is_const &&
+ frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdAnyFrame)
+ {
+ frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr);
+ } else {
+ frame = frame_ptr;
}
- return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, frame);
+ ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
+ IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
+ if (type_is_invalid(casted_frame->value.type))
+ return ira->codegen->invalid_instruction;
+
+ return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
}
static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
test/stage1/behavior/coroutines.zig
@@ -5,15 +5,20 @@ const expect = std.testing.expect;
var global_x: i32 = 1;
test "simple coroutine suspend and resume" {
- const p = async simpleAsyncFn();
+ const frame = async simpleAsyncFn();
expect(global_x == 2);
- resume p;
+ resume frame;
expect(global_x == 3);
+ const af: anyframe->void = &frame;
+ resume frame;
+ expect(global_x == 4);
}
fn simpleAsyncFn() void {
global_x += 1;
suspend;
global_x += 1;
+ suspend;
+ global_x += 1;
}
var global_y: i32 = 1;