Commit 65a03c5859

Andrew Kelley <superjoe30@gmail.com>
2016-02-07 00:36:49
implement %defer and ?defer
see #110
1 parent 34a7e6f
src/all_types.hpp
@@ -78,8 +78,18 @@ struct ConstExprValue {
     } data;
 };
 
+enum ReturnKnowledge {
+    ReturnKnowledgeUnknown,
+    ReturnKnowledgeKnownError,
+    ReturnKnowledgeKnownNonError,
+    ReturnKnowledgeKnownNull,
+    ReturnKnowledgeKnownNonNull,
+    ReturnKnowledgeSkipDefers,
+};
+
 struct Expr {
     TypeTableEntry *type_entry;
+    ReturnKnowledge return_knowledge;
 
     LLVMValueRef const_llvm_val;
     ConstExprValue const_val;
src/analyze.cpp
@@ -3693,11 +3693,13 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
     // explicit cast from child type of maybe type to maybe type
     if (wanted_type->id == TypeTableEntryIdMaybe) {
         if (types_match_const_cast_only(wanted_type->data.maybe.child_type, actual_type)) {
+            get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonNull;
             return resolve_cast(g, context, node, expr_node, wanted_type, CastOpMaybeWrap, true);
         } else if (actual_type->id == TypeTableEntryIdNumLitInt ||
                    actual_type->id == TypeTableEntryIdNumLitFloat)
         {
             if (num_lit_fits_in_other_type(g, expr_node, wanted_type->data.maybe.child_type)) {
+                get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonNull;
                 return resolve_cast(g, context, node, expr_node, wanted_type, CastOpMaybeWrap, true);
             } else {
                 return g->builtin_types.entry_invalid;
@@ -3708,11 +3710,13 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
     // explicit cast from child type of error type to error type
     if (wanted_type->id == TypeTableEntryIdErrorUnion) {
         if (types_match_const_cast_only(wanted_type->data.error.child_type, actual_type)) {
+            get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonError;
             return resolve_cast(g, context, node, expr_node, wanted_type, CastOpErrorWrap, true);
         } else if (actual_type->id == TypeTableEntryIdNumLitInt ||
                    actual_type->id == TypeTableEntryIdNumLitFloat)
         {
             if (num_lit_fits_in_other_type(g, expr_node, wanted_type->data.error.child_type)) {
+                get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownNonError;
                 return resolve_cast(g, context, node, expr_node, wanted_type, CastOpErrorWrap, true);
             } else {
                 return g->builtin_types.entry_invalid;
@@ -3724,6 +3728,7 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
     if (wanted_type->id == TypeTableEntryIdErrorUnion &&
         actual_type->id == TypeTableEntryIdPureError)
     {
+        get_resolved_expr(node)->return_knowledge = ReturnKnowledgeKnownError;
         return resolve_cast(g, context, node, expr_node, wanted_type, CastOpPureErrorWrap, false);
     }
 
@@ -4602,44 +4607,11 @@ static TypeTableEntry *analyze_defer(CodeGen *g, ImportTableEntry *import, Block
 
     node->data.defer.child_block = new_block_context(node, parent_context);
 
-    switch (node->data.defer.kind) {
-        case ReturnKindUnconditional:
-            {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
-                        node->data.defer.expr);
-                validate_voided_expr(g, node->data.defer.expr, resolved_type);
+    TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
+            node->data.defer.expr);
+    validate_voided_expr(g, node->data.defer.expr, resolved_type);
 
-                return g->builtin_types.entry_void;
-            }
-        case ReturnKindError:
-            {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
-                        node->data.defer.expr);
-                if (resolved_type->id == TypeTableEntryIdInvalid) {
-                    // OK
-                } else if (resolved_type->id == TypeTableEntryIdErrorUnion) {
-                    // OK
-                } else {
-                    add_node_error(g, node->data.defer.expr,
-                            buf_sprintf("expected error type, got '%s'", buf_ptr(&resolved_type->name)));
-                }
-                return g->builtin_types.entry_void;
-            }
-        case ReturnKindMaybe:
-            {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
-                        node->data.defer.expr);
-                if (resolved_type->id == TypeTableEntryIdInvalid) {
-                    // OK
-                } else if (resolved_type->id == TypeTableEntryIdMaybe) {
-                    // OK
-                } else {
-                    add_node_error(g, node->data.defer.expr,
-                            buf_sprintf("expected maybe type, got '%s'", buf_ptr(&resolved_type->name)));
-                }
-                return g->builtin_types.entry_void;
-            }
-    }
+    return g->builtin_types.entry_void;
 }
 
 static TypeTableEntry *analyze_string_literal_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
src/codegen.cpp
@@ -483,6 +483,7 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
             }
         case CastOpPureErrorWrap:
             assert(wanted_type->id == TypeTableEntryIdErrorUnion);
+
             if (!type_has_bits(wanted_type->data.error.child_type)) {
                 return expr_val;
             } else {
@@ -1593,18 +1594,48 @@ static LLVMValueRef gen_unwrap_err_expr(CodeGen *g, AstNode *node) {
     return phi;
 }
 
-static void gen_defers_for_block(CodeGen *g, BlockContext *inner_block, BlockContext *outer_block) {
+static void gen_defers_for_block(CodeGen *g, BlockContext *inner_block, BlockContext *outer_block,
+        bool gen_error_defers, bool gen_maybe_defers)
+{
     while (inner_block != outer_block) {
-        if (inner_block->node->type == NodeTypeDefer) {
+        if (inner_block->node->type == NodeTypeDefer &&
+           ((inner_block->node->data.defer.kind == ReturnKindUnconditional) ||
+            (gen_error_defers && inner_block->node->data.defer.kind == ReturnKindError) ||
+            (gen_maybe_defers && inner_block->node->data.defer.kind == ReturnKindMaybe)))
+        {
             gen_expr(g, inner_block->node->data.defer.expr);
         }
         inner_block = inner_block->parent;
     }
 }
 
-static LLVMValueRef gen_return(CodeGen *g, AstNode *source_node, LLVMValueRef value) {
-    gen_defers_for_block(g, source_node->block_context,
-            source_node->block_context->fn_entry->fn_def_node->block_context);
+static int get_conditional_defer_count(BlockContext *inner_block, BlockContext *outer_block) {
+    int result = 0;
+    while (inner_block != outer_block) {
+        if (inner_block->node->type == NodeTypeDefer &&
+           (inner_block->node->data.defer.kind == ReturnKindError ||
+            inner_block->node->data.defer.kind == ReturnKindMaybe))
+        {
+            result += 1;
+        }
+        inner_block = inner_block->parent;
+    }
+    return result;
+}
+
+static LLVMValueRef gen_return(CodeGen *g, AstNode *source_node, LLVMValueRef value, ReturnKnowledge rk) {
+    BlockContext *defer_inner_block = source_node->block_context;
+    BlockContext *defer_outer_block = source_node->block_context->fn_entry->fn_def_node->block_context;
+    if (rk == ReturnKnowledgeUnknown) {
+        if (get_conditional_defer_count(defer_inner_block, defer_outer_block) > 0) {
+            // generate branching code that checks the return value and generates defers
+            // if the return value is error
+            zig_panic("TODO");
+        }
+    } else if (rk != ReturnKnowledgeSkipDefers) {
+        gen_defers_for_block(g, defer_inner_block, defer_outer_block,
+                rk == ReturnKnowledgeKnownError, rk == ReturnKnowledgeKnownNull);
+    }
 
     TypeTableEntry *return_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
     if (handle_is_ptr(return_type)) {
@@ -1628,7 +1659,23 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
     switch (node->data.return_expr.kind) {
         case ReturnKindUnconditional:
             {
-                return gen_return(g, node, value);
+                Expr *expr = get_resolved_expr(param_node);
+                if (expr->const_val.ok) {
+                    if (value_type->id == TypeTableEntryIdErrorUnion) {
+                        if (expr->const_val.data.x_err.err) {
+                            expr->return_knowledge = ReturnKnowledgeKnownError;
+                        } else {
+                            expr->return_knowledge = ReturnKnowledgeKnownNonError;
+                        }
+                    } else if (value_type->id == TypeTableEntryIdMaybe) {
+                        if (expr->const_val.data.x_maybe) {
+                            expr->return_knowledge = ReturnKnowledgeKnownNonNull;
+                        } else {
+                            expr->return_knowledge = ReturnKnowledgeKnownNull;
+                        }
+                    }
+                }
+                return gen_return(g, node, value, expr->return_knowledge);
             }
         case ReturnKindError:
             {
@@ -1653,7 +1700,7 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
                 LLVMPositionBuilderAtEnd(g->builder, return_block);
                 TypeTableEntry *return_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
                 if (return_type->id == TypeTableEntryIdPureError) {
-                    gen_return(g, node, err_val);
+                    gen_return(g, node, err_val, ReturnKnowledgeKnownError);
                 } else if (return_type->id == TypeTableEntryIdErrorUnion) {
                     if (type_has_bits(return_type->data.error.child_type)) {
                         assert(g->cur_ret_ptr);
@@ -1663,7 +1710,7 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
                         LLVMBuildStore(g->builder, err_val, tag_ptr);
                         LLVMBuildRetVoid(g->builder);
                     } else {
-                        gen_return(g, node, err_val);
+                        gen_return(g, node, err_val, ReturnKnowledgeKnownError);
                     }
                 } else {
                     zig_unreachable();
@@ -1834,10 +1881,11 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i
         return nullptr;
     }
 
-    gen_defers_for_block(g, block_node->data.block.nested_block, block_node->data.block.child_block);
+    gen_defers_for_block(g, block_node->data.block.nested_block, block_node->data.block.child_block,
+            false, false);
 
     if (implicit_return_type) {
-        return gen_return(g, block_node, return_value);
+        return gen_return(g, block_node, return_value, ReturnKnowledgeSkipDefers);
     } else {
         return return_value;
     }
test/run_tests.cpp
@@ -1545,6 +1545,42 @@ pub fn main(args: [][]u8) -> %void {
 }
     )SOURCE", "before\ndefer2\ndefer1\n");
 
+
+    add_simple_case("%defer and it fails", R"SOURCE(
+import "std.zig";
+pub fn main(args: [][]u8) -> %void {
+    do_test() %% return;
+}
+fn do_test() -> %void {
+    %%stdout.printf("before\n");
+    defer %%stdout.printf("defer1\n");
+    %defer %%stdout.printf("deferErr\n");
+    %return its_gonna_fail();
+    defer %%stdout.printf("defer3\n");
+    %%stdout.printf("after\n");
+}
+error IToldYouItWouldFail;
+fn its_gonna_fail() -> %void {
+    return error.IToldYouItWouldFail;
+}
+    )SOURCE", "before\ndeferErr\ndefer1\n");
+
+
+    add_simple_case("%defer and it passes", R"SOURCE(
+import "std.zig";
+pub fn main(args: [][]u8) -> %void {
+    do_test() %% return;
+}
+fn do_test() -> %void {
+    %%stdout.printf("before\n");
+    defer %%stdout.printf("defer1\n");
+    %defer %%stdout.printf("deferErr\n");
+    %return its_gonna_pass();
+    defer %%stdout.printf("defer3\n");
+    %%stdout.printf("after\n");
+}
+fn its_gonna_pass() -> %void { }
+    )SOURCE", "before\nafter\ndefer3\ndefer1\n");
 }