Commit c75d40680f

Andrew Kelley <superjoe30@gmail.com>
2016-01-07 02:02:42
while detects simple constant condition
1 parent 5f0bfca
Changed files (6)
example/guess_number/main.zig
@@ -27,6 +27,8 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
     print_u64(answer);
     print_str("\n");
 
+    return 0;
+
     /*
     while (true) {
         const line = readline("\nGuess a number between 1 and 100: ");
@@ -45,6 +47,4 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
         }
     }
     */
-
-    return 0;
 }
src/analyze.cpp
@@ -12,6 +12,8 @@
 
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node);
+static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context,
+        AstNode *node, AstNodeNumberLiteral *out_number_literal);
 
 static AstNode *first_executing_node(AstNode *node) {
     switch (node->type) {
@@ -284,6 +286,98 @@ static TypeTableEntry *get_unknown_size_array_type(CodeGen *g, ImportTableEntry
     }
 }
 
+static TypeTableEntry *eval_const_expr_bin_op(CodeGen *g, BlockContext *context,
+        AstNode *node, AstNodeNumberLiteral *out_number_literal)
+{
+    AstNodeNumberLiteral op1_lit;
+    AstNodeNumberLiteral op2_lit;
+    TypeTableEntry *op1_type = eval_const_expr(g, context, node->data.bin_op_expr.op1, &op1_lit);
+    TypeTableEntry *op2_type = eval_const_expr(g, context, node->data.bin_op_expr.op1, &op2_lit);
+
+    if (op1_type->id == TypeTableEntryIdInvalid ||
+        op2_type->id == TypeTableEntryIdInvalid)
+    {
+        return g->builtin_types.entry_invalid;
+    }
+
+    // TODO complete more of this function instead of returning invalid
+    // returning invalid makes the "unable to evaluate constant expression" error
+
+    switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeCmpNotEq:
+            {
+                if (is_num_lit_unsigned(op1_lit.kind) &&
+                    is_num_lit_unsigned(op2_lit.kind))
+                {
+                    out_number_literal->kind = NumLitU8;
+                    out_number_literal->overflow = false;
+                    out_number_literal->data.x_uint = (op1_lit.data.x_uint != op2_lit.data.x_uint);
+                    return node->codegen_node->expr_node.type_entry;
+                } else {
+                    return g->builtin_types.entry_invalid;
+                }
+            }
+        case BinOpTypeCmpLessThan:
+            {
+                if (is_num_lit_unsigned(op1_lit.kind) &&
+                    is_num_lit_unsigned(op2_lit.kind))
+                {
+                    out_number_literal->kind = NumLitU8;
+                    out_number_literal->overflow = false;
+                    out_number_literal->data.x_uint = (op1_lit.data.x_uint < op2_lit.data.x_uint);
+                    return node->codegen_node->expr_node.type_entry;
+                } else {
+                    return g->builtin_types.entry_invalid;
+                }
+            }
+        case BinOpTypeMod:
+            {
+                if (is_num_lit_unsigned(op1_lit.kind) &&
+                    is_num_lit_unsigned(op2_lit.kind))
+                {
+                    out_number_literal->kind = NumLitU64;
+                    out_number_literal->overflow = false;
+                    out_number_literal->data.x_uint = (op1_lit.data.x_uint % op2_lit.data.x_uint);
+                    return node->codegen_node->expr_node.type_entry;
+                } else {
+                    return g->builtin_types.entry_invalid;
+                }
+            }
+        case BinOpTypeBoolOr:
+        case BinOpTypeBoolAnd:
+        case BinOpTypeCmpEq:
+        case BinOpTypeCmpGreaterThan:
+        case BinOpTypeCmpLessOrEq:
+        case BinOpTypeCmpGreaterOrEq:
+        case BinOpTypeBinOr:
+        case BinOpTypeBinXor:
+        case BinOpTypeBinAnd:
+        case BinOpTypeBitShiftLeft:
+        case BinOpTypeBitShiftRight:
+        case BinOpTypeAdd:
+        case BinOpTypeSub:
+        case BinOpTypeMult:
+        case BinOpTypeDiv:
+            return g->builtin_types.entry_invalid;
+        case BinOpTypeInvalid:
+        case BinOpTypeAssign:
+        case BinOpTypeAssignTimes:
+        case BinOpTypeAssignDiv:
+        case BinOpTypeAssignMod:
+        case BinOpTypeAssignPlus:
+        case BinOpTypeAssignMinus:
+        case BinOpTypeAssignBitShiftLeft:
+        case BinOpTypeAssignBitShiftRight:
+        case BinOpTypeAssignBitAnd:
+        case BinOpTypeAssignBitXor:
+        case BinOpTypeAssignBitOr:
+        case BinOpTypeAssignBoolAnd:
+        case BinOpTypeAssignBoolOr:
+            zig_unreachable();
+    }
+    zig_unreachable();
+}
+
 static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context,
         AstNode *node, AstNodeNumberLiteral *out_number_literal)
 {
@@ -291,9 +385,11 @@ static TypeTableEntry *eval_const_expr(CodeGen *g, BlockContext *context,
         case NodeTypeNumberLiteral:
             *out_number_literal = node->data.number_literal;
             return node->codegen_node->expr_node.type_entry;
+        case NodeTypeBoolLiteral:
+            out_number_literal->data.x_uint = node->data.bool_literal ? 1 : 0;
+            return node->codegen_node->expr_node.type_entry;
         case NodeTypeBinOpExpr:
-            zig_panic("TODO eval_const_expr bin op expr");
-            break;
+            return eval_const_expr_bin_op(g, context, node, out_number_literal);
         case NodeTypeCompilerFnType:
             {
                 Buf *name = &node->data.compiler_fn_type.name;
@@ -1133,8 +1229,12 @@ BlockContext *new_block_context(AstNode *node, BlockContext *parent) {
     context->variable_table.init(8);
 
     if (parent) {
-        context->break_allowed = parent->break_allowed || parent->next_child_break_allowed;
-        parent->next_child_break_allowed = false;
+        if (parent->next_child_parent_loop_node) {
+            context->parent_loop_node = parent->next_child_parent_loop_node;
+            parent->next_child_parent_loop_node = nullptr;
+        } else {
+            context->parent_loop_node = parent->parent_loop_node;
+        }
     }
 
     if (node && node->type == NodeTypeFnDef) {
@@ -1690,20 +1790,45 @@ static TypeTableEntry *analyze_struct_val_expr(CodeGen *g, ImportTableEntry *imp
 static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    analyze_expression(g, import, context, g->builtin_types.entry_bool, node->data.while_expr.condition);
+    AstNode *condition_node = node->data.while_expr.condition;
+    AstNode *while_body_node = node->data.while_expr.body;
+    TypeTableEntry *condition_type = analyze_expression(g, import, context,
+            g->builtin_types.entry_bool, condition_node);
+
+    context->next_child_parent_loop_node = node;
+    analyze_expression(g, import, context, g->builtin_types.entry_void, while_body_node);
+
+
+    TypeTableEntry *expr_return_type = g->builtin_types.entry_void;
 
-    context->next_child_break_allowed = true;
-    analyze_expression(g, import, context, g->builtin_types.entry_void, node->data.while_expr.body);
+    if (condition_type->id == TypeTableEntryIdInvalid) {
+        expr_return_type = g->builtin_types.entry_invalid;
+    } else {
+        // if the condition is a simple constant expression and there are no break statements
+        // then the return type is unreachable
+        AstNodeNumberLiteral number_literal;
+        TypeTableEntry *resolved_type = eval_const_expr(g, context, condition_node, &number_literal);
+        if (resolved_type->id != TypeTableEntryIdInvalid) {
+            assert(resolved_type->id == TypeTableEntryIdBool);
+            bool constant_cond_value = number_literal.data.x_uint;
+            if (constant_cond_value && !node->codegen_node->data.while_node.contains_break) {
+                expr_return_type = g->builtin_types.entry_unreachable;
+            }
+        }
+    }
 
-    return g->builtin_types.entry_void;
+    return expr_return_type;
 }
 
 static TypeTableEntry *analyze_break_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    if (!context->break_allowed) {
+    AstNode *loop_node = context->parent_loop_node;
+    if (loop_node) {
+        loop_node->codegen_node->data.while_node.contains_break = true;
+    } else {
         add_node_error(g, node,
-                buf_sprintf("'break' expression not in loop"));
+                buf_sprintf("'break' expression outside loop"));
     }
     return g->builtin_types.entry_unreachable;
 }
@@ -1711,9 +1836,9 @@ static TypeTableEntry *analyze_break_expr(CodeGen *g, ImportTableEntry *import,
 static TypeTableEntry *analyze_continue_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    if (!context->break_allowed) {
+    if (!context->parent_loop_node) {
         add_node_error(g, node,
-                buf_sprintf("'continue' expression not in loop"));
+                buf_sprintf("'continue' expression outside loop"));
     }
     return g->builtin_types.entry_unreachable;
 }
src/analyze.hpp
@@ -244,8 +244,8 @@ struct BlockContext {
     HashMap<Buf *, VariableTableEntry *, buf_hash, buf_eql_buf> variable_table;
     ZigList<CastNode *> cast_expr_alloca_list;
     ZigList<StructValExprNode *> struct_val_expr_alloca_list;
-    bool break_allowed;
-    bool next_child_break_allowed;
+    AstNode *parent_loop_node;
+    AstNode *next_child_parent_loop_node;
     LLVMZigDIScope *di_scope;
 };
 
@@ -340,6 +340,10 @@ struct ImportNode {
     ImportTableEntry *import;
 };
 
+struct WhileNode {
+    bool contains_break;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -358,6 +362,7 @@ struct CodeGenNode {
         IfVarNode if_var_node; // for NodeTypeStructValueExpr
         ParamDeclNode param_decl_node; // for NodeTypeParamDecl
         ImportNode import_node; // for NodeTypeUse
+        WhileNode while_node; // for NodeTypeWhileExpr
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
src/codegen.cpp
@@ -1157,29 +1157,52 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
     assert(node->data.while_expr.condition);
     assert(node->data.while_expr.body);
 
-    LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileCond");
-    LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
-    LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
+    if (get_expr_type(node)->id == TypeTableEntryIdUnreachable) {
+        // generate a forever loop. guarantees no break statements
 
-    add_debug_source_node(g, node);
-    LLVMBuildBr(g->builder, cond_block);
-
-    LLVMPositionBuilderAtEnd(g->builder, cond_block);
-    LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition);
-    add_debug_source_node(g, node->data.while_expr.condition);
-    LLVMBuildCondBr(g->builder, cond_val, body_block, end_block);
-
-    LLVMPositionBuilderAtEnd(g->builder, body_block);
-    g->break_block_stack.append(end_block);
-    g->continue_block_stack.append(cond_block);
-    gen_expr(g, node->data.while_expr.body);
-    g->break_block_stack.pop();
-    g->continue_block_stack.pop();
-    if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
+        LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
+
+        add_debug_source_node(g, node);
+        LLVMBuildBr(g->builder, body_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, body_block);
+        g->continue_block_stack.append(body_block);
+        gen_expr(g, node->data.while_expr.body);
+        g->continue_block_stack.pop();
+
+        if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
+            add_debug_source_node(g, node);
+            LLVMBuildBr(g->builder, body_block);
+        }
+    } else {
+        // generate a normal while loop
+
+        LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileCond");
+        LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
+        LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
+
+        add_debug_source_node(g, node);
         LLVMBuildBr(g->builder, cond_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, cond_block);
+        LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition);
+        add_debug_source_node(g, node->data.while_expr.condition);
+        LLVMBuildCondBr(g->builder, cond_val, body_block, end_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, body_block);
+        g->break_block_stack.append(end_block);
+        g->continue_block_stack.append(cond_block);
+        gen_expr(g, node->data.while_expr.body);
+        g->break_block_stack.pop();
+        g->continue_block_stack.pop();
+        if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
+            add_debug_source_node(g, node);
+            LLVMBuildBr(g->builder, cond_block);
+        }
+
+        LLVMPositionBuilderAtEnd(g->builder, end_block);
     }
 
-    LLVMPositionBuilderAtEnd(g->builder, end_block);
     return nullptr;
 }
 
std/rand.zig
@@ -67,9 +67,6 @@ pub struct Rand {
                 return start + (rand_val % range);
             }
         }
-        // TODO detect simple constant in while loop and no breaks and turn it into unreachable
-        // type. then we can remove this unreachable.
-        unreachable;
     }
 
     fn generate_numbers(r: &Rand) {
test/run_tests.cpp
@@ -683,7 +683,12 @@ pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
         print_str("loop\n");
         i += 1;
     }
-    return 0;
+    return f();
+}
+fn f() -> i32 {
+    while (true) {
+        return 0;
+    }
 }
     )SOURCE", "loop\nloop\nloop\nloop\n");
 
@@ -1168,13 +1173,13 @@ fn f() {
 fn f() {
     break;
 }
-    )SOURCE", 1, ".tmp_source.zig:3:5: error: 'break' expression not in loop");
+    )SOURCE", 1, ".tmp_source.zig:3:5: error: 'break' expression outside loop");
 
     add_compile_fail_case("invalid continue expression", R"SOURCE(
 fn f() {
     continue;
 }
-    )SOURCE", 1, ".tmp_source.zig:3:5: error: 'continue' expression not in loop");
+    )SOURCE", 1, ".tmp_source.zig:3:5: error: 'continue' expression outside loop");
 
     add_compile_fail_case("invalid maybe type", R"SOURCE(
 fn f() {