Commit 1fe0142d1f

Andrew Kelley <andrew@ziglang.org>
2021-06-18 07:40:07
AstGen: properly generate errdefer expressions when returning
`return` statements use a new function `nodeMayEvalToError` which does some basic checks on the AST node to return never, always, or maybe. Depending on this result, AstGen skips the errdefers, always includes the errdefers, or emits a conditional branch to check whether the return value is an error that Sema will have to evaluate. Closes #8821 Unblocks #9047
1 parent 76102ea
Changed files (1)
src/AstGen.zig
@@ -5998,22 +5998,55 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref
     if (gz.in_defer) return astgen.failNode(node, "cannot return from defer expression", .{});
 
     const operand_node = node_datas[node].lhs;
-    if (operand_node != 0) {
-        const rl: ResultLoc = if (nodeMayNeedMemoryLocation(tree, operand_node)) .{
-            .ptr = try gz.addNodeExtended(.ret_ptr, node),
-        } else .{
-            .ty = try gz.addNodeExtended(.ret_type, node),
-        };
-        const operand = try expr(gz, scope, rl, operand_node);
-        // TODO check operand to see if we need to generate errdefers
+    if (operand_node == 0) {
+        // Returning a void value; skip error defers.
         try genDefers(gz, &astgen.fn_block.?.base, scope, .none);
-        _ = try gz.addUnNode(.ret_node, operand, node);
+        _ = try gz.addUnNode(.ret_node, .void_value, node);
         return Zir.Inst.Ref.unreachable_value;
     }
-    // Returning a void value; skip error defers.
-    try genDefers(gz, &astgen.fn_block.?.base, scope, .none);
-    _ = try gz.addUnNode(.ret_node, .void_value, node);
-    return Zir.Inst.Ref.unreachable_value;
+
+    const rl: ResultLoc = if (nodeMayNeedMemoryLocation(tree, operand_node)) .{
+        .ptr = try gz.addNodeExtended(.ret_ptr, node),
+    } else .{
+        .ty = try gz.addNodeExtended(.ret_type, node),
+    };
+    const operand = try expr(gz, scope, rl, operand_node);
+
+    switch (nodeMayEvalToError(tree, operand_node)) {
+        .never => {
+            // Returning a value that cannot be an error; skip error defers.
+            try genDefers(gz, &astgen.fn_block.?.base, scope, .none);
+            _ = try gz.addUnNode(.ret_node, operand, node);
+            return Zir.Inst.Ref.unreachable_value;
+        },
+        .always => {
+            // Value is always an error. Emit both error defers and regular defers.
+            const err_code = try gz.addUnNode(.err_union_code, operand, node);
+            try genDefers(gz, &astgen.fn_block.?.base, scope, err_code);
+            _ = try gz.addUnNode(.ret_node, operand, node);
+            return Zir.Inst.Ref.unreachable_value;
+        },
+        .maybe => {
+            // Emit conditional branch for generating errdefers.
+            const is_err = try gz.addUnNode(.is_err, operand, node);
+            const condbr = try gz.addCondBr(.condbr, node);
+
+            var then_scope = gz.makeSubBlock(scope);
+            defer then_scope.instructions.deinit(astgen.gpa);
+            const err_code = try then_scope.addUnNode(.err_union_code, operand, node);
+            try genDefers(&then_scope, &astgen.fn_block.?.base, scope, err_code);
+            _ = try then_scope.addUnNode(.ret_node, operand, node);
+
+            var else_scope = gz.makeSubBlock(scope);
+            defer else_scope.instructions.deinit(astgen.gpa);
+            try genDefers(&else_scope, &astgen.fn_block.?.base, scope, .none);
+            _ = try else_scope.addUnNode(.ret_node, operand, node);
+
+            try setCondBrPayload(condbr, is_err, &then_scope, &else_scope);
+
+            return Zir.Inst.Ref.unreachable_value;
+        },
+    }
 }
 
 fn identifier(
@@ -7555,6 +7588,219 @@ fn nodeMayNeedMemoryLocation(tree: *const ast.Tree, start_node: ast.Node.Index)
     }
 }
 
+fn nodeMayEvalToError(tree: *const ast.Tree, start_node: ast.Node.Index) enum { never, always, maybe } {
+    const node_tags = tree.nodes.items(.tag);
+    const node_datas = tree.nodes.items(.data);
+    const main_tokens = tree.nodes.items(.main_token);
+    const token_tags = tree.tokens.items(.tag);
+
+    var node = start_node;
+    while (true) {
+        switch (node_tags[node]) {
+            .root,
+            .@"usingnamespace",
+            .test_decl,
+            .switch_case,
+            .switch_case_one,
+            .container_field_init,
+            .container_field_align,
+            .container_field,
+            .asm_output,
+            .asm_input,
+            => unreachable,
+
+            .error_value => return .always,
+
+            .@"asm",
+            .asm_simple,
+            .identifier,
+            .field_access,
+            .deref,
+            .array_access,
+            .while_simple,
+            .while_cont,
+            .for_simple,
+            .if_simple,
+            .@"while",
+            .@"if",
+            .@"for",
+            .@"switch",
+            .switch_comma,
+            .call_one,
+            .call_one_comma,
+            .async_call_one,
+            .async_call_one_comma,
+            .call,
+            .call_comma,
+            .async_call,
+            .async_call_comma,
+            => return .maybe,
+
+            .@"return",
+            .@"break",
+            .@"continue",
+            .bit_not,
+            .bool_not,
+            .global_var_decl,
+            .local_var_decl,
+            .simple_var_decl,
+            .aligned_var_decl,
+            .@"defer",
+            .@"errdefer",
+            .address_of,
+            .optional_type,
+            .negation,
+            .negation_wrap,
+            .@"resume",
+            .array_type,
+            .array_type_sentinel,
+            .ptr_type_aligned,
+            .ptr_type_sentinel,
+            .ptr_type,
+            .ptr_type_bit_range,
+            .@"suspend",
+            .@"anytype",
+            .fn_proto_simple,
+            .fn_proto_multi,
+            .fn_proto_one,
+            .fn_proto,
+            .fn_decl,
+            .anyframe_type,
+            .anyframe_literal,
+            .integer_literal,
+            .float_literal,
+            .enum_literal,
+            .string_literal,
+            .multiline_string_literal,
+            .char_literal,
+            .true_literal,
+            .false_literal,
+            .null_literal,
+            .undefined_literal,
+            .unreachable_literal,
+            .error_set_decl,
+            .container_decl,
+            .container_decl_trailing,
+            .container_decl_two,
+            .container_decl_two_trailing,
+            .container_decl_arg,
+            .container_decl_arg_trailing,
+            .tagged_union,
+            .tagged_union_trailing,
+            .tagged_union_two,
+            .tagged_union_two_trailing,
+            .tagged_union_enum_tag,
+            .tagged_union_enum_tag_trailing,
+            .add,
+            .add_wrap,
+            .array_cat,
+            .array_mult,
+            .assign,
+            .assign_bit_and,
+            .assign_bit_or,
+            .assign_bit_shift_left,
+            .assign_bit_shift_right,
+            .assign_bit_xor,
+            .assign_div,
+            .assign_sub,
+            .assign_sub_wrap,
+            .assign_mod,
+            .assign_add,
+            .assign_add_wrap,
+            .assign_mul,
+            .assign_mul_wrap,
+            .bang_equal,
+            .bit_and,
+            .bit_or,
+            .bit_shift_left,
+            .bit_shift_right,
+            .bit_xor,
+            .bool_and,
+            .bool_or,
+            .div,
+            .equal_equal,
+            .error_union,
+            .greater_or_equal,
+            .greater_than,
+            .less_or_equal,
+            .less_than,
+            .merge_error_sets,
+            .mod,
+            .mul,
+            .mul_wrap,
+            .switch_range,
+            .sub,
+            .sub_wrap,
+            .slice,
+            .slice_open,
+            .slice_sentinel,
+            .array_init_one,
+            .array_init_one_comma,
+            .array_init_dot_two,
+            .array_init_dot_two_comma,
+            .array_init_dot,
+            .array_init_dot_comma,
+            .array_init,
+            .array_init_comma,
+            .struct_init_one,
+            .struct_init_one_comma,
+            .struct_init_dot_two,
+            .struct_init_dot_two_comma,
+            .struct_init_dot,
+            .struct_init_dot_comma,
+            .struct_init,
+            .struct_init_comma,
+            => return .never,
+
+            // Forward the question to the LHS sub-expression.
+            .grouped_expression,
+            .@"try",
+            .@"await",
+            .@"comptime",
+            .@"nosuspend",
+            .unwrap_optional,
+            => node = node_datas[node].lhs,
+
+            // Forward the question to the RHS sub-expression.
+            .@"catch",
+            .@"orelse",
+            => node = node_datas[node].rhs,
+
+            .block_two,
+            .block_two_semicolon,
+            .block,
+            .block_semicolon,
+            => {
+                const lbrace = main_tokens[node];
+                if (token_tags[lbrace - 1] == .colon) {
+                    // Labeled blocks may need a memory location to forward
+                    // to their break statements.
+                    return .maybe;
+                } else {
+                    return .never;
+                }
+            },
+
+            .builtin_call,
+            .builtin_call_comma,
+            .builtin_call_two,
+            .builtin_call_two_comma,
+            => {
+                const builtin_token = main_tokens[node];
+                const builtin_name = tree.tokenSlice(builtin_token);
+                // If the builtin is an invalid name, we don't cause an error here; instead
+                // let it pass, and the error will be "invalid builtin function" later.
+                const builtin_info = BuiltinFn.list.get(builtin_name) orelse return .maybe;
+                if (builtin_info.tag == .err_set_cast) {
+                    return .always;
+                } else {
+                    return .never;
+                }
+            },
+        }
+    }
+}
+
 /// Applies `rl` semantics to `inst`. Expressions which do not do their own handling of
 /// result locations must call this function on their result.
 /// As an example, if the `ResultLoc` is `ptr`, it will write the result to the pointer.