Commit 6ab8b2aab4

Andrew Kelley <andrew@ziglang.org>
2019-08-31 02:06:02
support recursive async and non-async functions
which heap allocate their own frames related: #1006
1 parent 2148943
src/all_types.hpp
@@ -627,7 +627,7 @@ struct AstNodeParamDecl {
     AstNode *type;
     Token *var_token;
     bool is_noalias;
-    bool is_inline;
+    bool is_comptime;
     bool is_var_args;
 };
 
src/analyze.cpp
@@ -1556,7 +1556,7 @@ static ZigType *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_sc
         AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index);
         assert(param_node->type == NodeTypeParamDecl);
 
-        bool param_is_comptime = param_node->data.param_decl.is_inline;
+        bool param_is_comptime = param_node->data.param_decl.is_comptime;
         bool param_is_var_args = param_node->data.param_decl.is_var_args;
 
         if (param_is_comptime) {
@@ -8234,6 +8234,10 @@ static void resolve_llvm_types_anyerror(CodeGen *g) {
 }
 
 static void resolve_llvm_types_async_frame(CodeGen *g, ZigType *frame_type, ResolveStatus wanted_resolve_status) {
+    Error err;
+    if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown)))
+        zig_unreachable();
+
     ZigType *passed_frame_type = fn_is_async(frame_type->data.frame.fn) ? frame_type : nullptr;
     resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status, passed_frame_type);
     frame_type->llvm_type = frame_type->data.frame.locals_struct->llvm_type;
@@ -8375,7 +8379,6 @@ static void resolve_llvm_types_any_frame(CodeGen *g, ZigType *any_frame_type, Re
 }
 
 static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_resolve_status) {
-    assert(type->id == ZigTypeIdOpaque || type_is_resolved(type, ResolveStatusSizeKnown));
     assert(wanted_resolve_status > ResolveStatusSizeKnown);
     switch (type->id) {
         case ZigTypeIdInvalid:
src/ast_render.cpp
@@ -448,7 +448,7 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
                     assert(param_decl->type == NodeTypeParamDecl);
                     if (param_decl->data.param_decl.name != nullptr) {
                         const char *noalias_str = param_decl->data.param_decl.is_noalias ? "noalias " : "";
-                        const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : "";
+                        const char *inline_str = param_decl->data.param_decl.is_comptime ? "comptime " : "";
                         fprintf(ar->f, "%s%s", noalias_str, inline_str);
                         print_symbol(ar, param_decl->data.param_decl.name);
                         fprintf(ar->f, ": ");
src/codegen.cpp
@@ -6340,9 +6340,12 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val, const c
     ZigType *type_entry = const_val->type;
     assert(type_has_bits(type_entry));
 
-    switch (const_val->special) {
+check: switch (const_val->special) {
         case ConstValSpecialLazy:
-            zig_unreachable();
+            if ((err = ir_resolve_lazy(g, nullptr, const_val))) {
+                report_errors_and_exit(g);
+            }
+            goto check;
         case ConstValSpecialRuntime:
             zig_unreachable();
         case ConstValSpecialUndef:
src/ir.cpp
@@ -9012,7 +9012,42 @@ static bool ir_num_lit_fits_in_other_type(IrAnalyze *ira, IrInstruction *instruc
         return false;
     }
 
-    ConstExprValue *const_val = ir_resolve_const(ira, instruction, UndefBad);
+    ConstExprValue *const_val = ir_resolve_const(ira, instruction, LazyOkNoUndef);
+    if (const_val == nullptr)
+        return false;
+
+    if (const_val->special == ConstValSpecialLazy) {
+        switch (const_val->data.x_lazy->id) {
+            case LazyValueIdAlignOf: {
+                // This is guaranteed to fit into a u29
+                if (other_type->id == ZigTypeIdComptimeInt)
+                    return true;
+                size_t align_bits = get_align_amt_type(ira->codegen)->data.integral.bit_count;
+                if (other_type->id == ZigTypeIdInt && !other_type->data.integral.is_signed &&
+                    other_type->data.integral.bit_count >= align_bits)
+                {
+                    return true;
+                }
+                break;
+            }
+            case LazyValueIdSizeOf: {
+                // This is guaranteed to fit into a usize
+                if (other_type->id == ZigTypeIdComptimeInt)
+                    return true;
+                size_t usize_bits = ira->codegen->builtin_types.entry_usize->data.integral.bit_count;
+                if (other_type->id == ZigTypeIdInt && !other_type->data.integral.is_signed &&
+                    other_type->data.integral.bit_count >= usize_bits)
+                {
+                    return true;
+                }
+                break;
+            }
+            default:
+                break;
+        }
+    }
+
+    const_val = ir_resolve_const(ira, instruction, UndefBad);
     if (const_val == nullptr)
         return false;
 
@@ -10262,7 +10297,7 @@ static void copy_const_val(ConstExprValue *dest, ConstExprValue *src, bool same_
     memcpy(dest, src, sizeof(ConstExprValue));
     if (!same_global_refs) {
         dest->global_refs = global_refs;
-        if (src->special == ConstValSpecialUndef)
+        if (src->special != ConstValSpecialStatic)
             return;
         if (dest->type->id == ZigTypeIdStruct) {
             dest->data.x_struct.fields = create_const_vals(dest->type->data.structure.src_field_count);
@@ -11213,7 +11248,7 @@ static IrInstruction *ir_get_ref(IrAnalyze *ira, IrInstruction *source_instructi
         return ira->codegen->invalid_instruction;
 
     if (instr_is_comptime(value)) {
-        ConstExprValue *val = ir_resolve_const(ira, value, UndefOk);
+        ConstExprValue *val = ir_resolve_const(ira, value, LazyOk);
         if (!val)
             return ira->codegen->invalid_instruction;
         return ir_get_const_ptr(ira, source_instruction, val, value->value.type,
@@ -12125,7 +12160,8 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
             if (wanted_type->id == ZigTypeIdComptimeInt || wanted_type->id == ZigTypeIdInt) {
                 IrInstruction *result = ir_const(ira, source_instr, wanted_type);
                 if (actual_type->id == ZigTypeIdComptimeInt || actual_type->id == ZigTypeIdInt) {
-                    bigint_init_bigint(&result->value.data.x_bigint, &value->value.data.x_bigint);
+                    copy_const_val(&result->value, &value->value, false);
+                    result->value.type = wanted_type;
                 } else {
                     float_init_bigint(&result->value.data.x_bigint, &value->value);
                 }
@@ -15301,7 +15337,7 @@ static bool ir_analyze_fn_call_generic_arg(IrAnalyze *ira, AstNode *fn_proto_nod
         }
     }
 
-    bool comptime_arg = param_decl_node->data.param_decl.is_inline ||
+    bool comptime_arg = param_decl_node->data.param_decl.is_comptime ||
         casted_arg->value.type->id == ZigTypeIdComptimeInt || casted_arg->value.type->id == ZigTypeIdComptimeFloat;
 
     ConstExprValue *arg_val;
@@ -17594,6 +17630,11 @@ static IrInstruction *ir_analyze_instruction_field_ptr(IrAnalyze *ira, IrInstruc
         ConstExprValue *child_val = const_ptr_pointee(ira, ira->codegen, container_ptr_val, source_node);
         if (child_val == nullptr)
             return ira->codegen->invalid_instruction;
+        if ((err = ir_resolve_const_val(ira->codegen, ira->new_irb.exec,
+            field_ptr_instruction->base.source_node, child_val, UndefBad)))
+        {
+            return ira->codegen->invalid_instruction;
+        }
         ZigType *child_type = child_val->data.x_type;
 
         if (type_is_invalid(child_type)) {
@@ -21293,8 +21334,10 @@ static IrInstruction *ir_analyze_instruction_from_bytes(IrAnalyze *ira, IrInstru
         src_ptr_align = get_abi_alignment(ira->codegen, target->value.type);
     }
 
-    if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusSizeKnown)))
-        return ira->codegen->invalid_instruction;
+    if (src_ptr_align != 0) {
+        if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusAlignmentKnown)))
+            return ira->codegen->invalid_instruction;
+    }
 
     ZigType *dest_ptr_type = get_pointer_to_type_extra(ira->codegen, dest_child_type,
             src_ptr_const, src_ptr_volatile, PtrLenUnknown,
@@ -21337,6 +21380,8 @@ static IrInstruction *ir_analyze_instruction_from_bytes(IrAnalyze *ira, IrInstru
     }
 
     if (have_known_len) {
+        if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusSizeKnown)))
+            return ira->codegen->invalid_instruction;
         uint64_t child_type_size = type_size(ira->codegen, dest_child_type);
         uint64_t remainder = known_len % child_type_size;
         if (remainder != 0) {
@@ -23963,15 +24008,23 @@ static IrInstruction *ir_analyze_instruction_ptr_type(IrAnalyze *ira, IrInstruct
 }
 
 static IrInstruction *ir_analyze_instruction_align_cast(IrAnalyze *ira, IrInstructionAlignCast *instruction) {
-    uint32_t align_bytes;
-    IrInstruction *align_bytes_inst = instruction->align_bytes->child;
-    if (!ir_resolve_align(ira, align_bytes_inst, nullptr, &align_bytes))
-        return ira->codegen->invalid_instruction;
-
     IrInstruction *target = instruction->target->child;
     if (type_is_invalid(target->value.type))
         return ira->codegen->invalid_instruction;
 
+    ZigType *elem_type = nullptr;
+    if (is_slice(target->value.type)) {
+        ZigType *slice_ptr_type = target->value.type->data.structure.fields[slice_ptr_index].type_entry;
+        elem_type = slice_ptr_type->data.pointer.child_type;
+    } else if (target->value.type->id == ZigTypeIdPointer) {
+        elem_type = target->value.type->data.pointer.child_type;
+    }
+
+    uint32_t align_bytes;
+    IrInstruction *align_bytes_inst = instruction->align_bytes->child;
+    if (!ir_resolve_align(ira, align_bytes_inst, elem_type, &align_bytes))
+        return ira->codegen->invalid_instruction;
+
     IrInstruction *result = ir_align_cast(ira, target, align_bytes, true);
     if (type_is_invalid(result->value.type))
         return ira->codegen->invalid_instruction;
@@ -25644,7 +25697,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
             }
 
             val->special = ConstValSpecialStatic;
-            assert(val->type->id == ZigTypeIdComptimeInt);
+            assert(val->type->id == ZigTypeIdComptimeInt || val->type->id == ZigTypeIdInt);
             bigint_init_unsigned(&val->data.x_bigint, align_in_bytes);
             return ErrorNone;
         }
@@ -25699,7 +25752,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
             }
 
             val->special = ConstValSpecialStatic;
-            assert(val->type->id == ZigTypeIdComptimeInt);
+            assert(val->type->id == ZigTypeIdComptimeInt || val->type->id == ZigTypeIdInt);
             bigint_init_unsigned(&val->data.x_bigint, abi_size);
             return ErrorNone;
         }
@@ -25885,7 +25938,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
 Error ir_resolve_lazy(CodeGen *codegen, AstNode *source_node, ConstExprValue *val) {
     Error err;
     if ((err = ir_resolve_lazy_raw(source_node, val))) {
-        if (codegen->trace_err != nullptr && !source_node->already_traced_this_node) {
+        if (codegen->trace_err != nullptr && source_node != nullptr && !source_node->already_traced_this_node) {
             source_node->already_traced_this_node = true;
             codegen->trace_err = add_error_note(codegen, codegen->trace_err, source_node,
                 buf_create_from_str("referenced here"));
src/parser.cpp
@@ -2075,7 +2075,7 @@ static AstNode *ast_parse_param_decl(ParseContext *pc) {
     res->column = first->start_column;
     res->data.param_decl.name = token_buf(name);
     res->data.param_decl.is_noalias = first->id == TokenIdKeywordNoAlias;
-    res->data.param_decl.is_inline = first->id == TokenIdKeywordCompTime;
+    res->data.param_decl.is_comptime = first->id == TokenIdKeywordCompTime;
     return res;
 }
 
std/mem.zig
@@ -117,7 +117,15 @@ pub const Allocator = struct {
         const byte_slice = try self.reallocFn(self, ([*]u8)(undefined)[0..0], undefined, byte_count, a);
         assert(byte_slice.len == byte_count);
         @memset(byte_slice.ptr, undefined, byte_slice.len);
-        return @bytesToSlice(T, @alignCast(a, byte_slice));
+        if (alignment == null) {
+            // TODO This is a workaround for zig not being able to successfully do
+            // @bytesToSlice(T, @alignCast(a, byte_slice)) without resolving alignment of T,
+            // which causes a circular dependency in async functions which try to heap-allocate
+            // their own frame with @Frame(func).
+            return @intToPtr([*]T, @ptrToInt(byte_slice.ptr))[0..n];
+        } else {
+            return @bytesToSlice(T, @alignCast(a, byte_slice));
+        }
     }
 
     /// This function requests a new byte size for an existing allocation,
test/stage1/behavior/async_fn.zig
@@ -854,3 +854,68 @@ test "await does not force async if callee is blocking" {
     var x = async S.simple();
     expect(await x == 1234);
 }
+
+test "recursive async function" {
+    expect(recursiveAsyncFunctionTest(false).doTheTest() == 55);
+    expect(recursiveAsyncFunctionTest(true).doTheTest() == 55);
+}
+
+fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
+    return struct {
+        fn fib(allocator: *std.mem.Allocator, x: u32) error{OutOfMemory}!u32 {
+            if (x <= 1) return x;
+
+            if (suspending_implementation) {
+                suspend {
+                    resume @frame();
+                }
+            }
+
+            const f1 = try allocator.create(@Frame(fib));
+            defer allocator.destroy(f1);
+
+            const f2 = try allocator.create(@Frame(fib));
+            defer allocator.destroy(f2);
+
+            f1.* = async fib(allocator, x - 1);
+            var f1_awaited = false;
+            errdefer if (!f1_awaited) {
+                _ = await f1;
+            };
+
+            f2.* = async fib(allocator, x - 2);
+            var f2_awaited = false;
+            errdefer if (!f2_awaited) {
+                _ = await f2;
+            };
+
+            var sum: u32 = 0;
+
+            f1_awaited = true;
+            const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
+            sum += try result_f1;
+
+            f2_awaited = true;
+            const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
+            sum += try result_f2;
+
+            return sum;
+        }
+
+        fn doTheTest() u32 {
+            if (suspending_implementation) {
+                var result: u32 = undefined;
+                _ = async amain(&result);
+                return result;
+            } else {
+                return fib(std.heap.direct_allocator, 10) catch unreachable;
+            }
+        }
+
+        fn amain(result: *u32) void {
+            var x = async fib(std.heap.direct_allocator, 10);
+            const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
+            result.* = res catch unreachable;
+        }
+    };
+}
test/compile_errors.zig
@@ -1051,6 +1051,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         \\const Foo = struct {};
         \\export fn a() void {
         \\    const T = [*c]Foo;
+        \\    var t: T = undefined;
         \\}
     ,
         "tmp.zig:3:19: error: C pointers cannot point to non-C-ABI-compatible type 'Foo'",
@@ -2290,6 +2291,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         "error union operator with non error set LHS",
         \\comptime {
         \\    const z = i32!i32;
+        \\    var x: z = undefined;
         \\}
     ,
         "tmp.zig:2:15: error: expected error set type, found type 'i32'",