Commit a25307c0a1

Andrew Kelley <superjoe30@gmail.com>
2016-04-20 05:28:44
add optional continue expression to while loop
closes #139
1 parent 04364c4
doc/langref.md
@@ -79,7 +79,7 @@ SwitchProng = (list(SwitchItem, ",") | "else") "=>" option("|" "Symbol" "|") Exp
 
 SwitchItem = Expression | (Expression "..." Expression)
 
-WhileExpression = "while" "(" Expression ")" Expression
+WhileExpression = "while" "(" Expression option(";" Expression) ")" Expression
 
 ForExpression = "for" "(" Expression ")" option("|" "Symbol" option("," "Symbol") "|") Expression
 
src/all_types.hpp
@@ -486,6 +486,7 @@ struct AstNodeIfVarExpr {
 
 struct AstNodeWhileExpr {
     AstNode *condition;
+    AstNode *continue_expr;
     AstNode *body;
 
     // populated by semantic analyzer
src/analyze.cpp
@@ -3507,10 +3507,15 @@ static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import,
 
     AstNode *condition_node = node->data.while_expr.condition;
     AstNode *while_body_node = node->data.while_expr.body;
+    AstNode **continue_expr_node = &node->data.while_expr.continue_expr;
 
     TypeTableEntry *condition_type = analyze_expression(g, import, context,
             g->builtin_types.entry_bool, condition_node);
 
+    if (*continue_expr_node) {
+        analyze_expression(g, import, context, g->builtin_types.entry_void, *continue_expr_node);
+    }
+
     BlockContext *child_context = new_block_context(node, context);
     child_context->parent_loop_node = node;
 
src/codegen.cpp
@@ -2282,12 +2282,16 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
     assert(node->data.while_expr.condition);
     assert(node->data.while_expr.body);
 
+    AstNode *continue_expr_node = node->data.while_expr.continue_expr;
+
     bool condition_always_true = node->data.while_expr.condition_always_true;
     bool contains_break = node->data.while_expr.contains_break;
     if (condition_always_true) {
         // generate a forever loop
 
         LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
+        LLVMBasicBlockRef continue_block = continue_expr_node ?
+            LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileContinue") : body_block;
         LLVMBasicBlockRef end_block = nullptr;
         if (contains_break) {
             end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
@@ -2296,16 +2300,25 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
         add_debug_source_node(g, node);
         LLVMBuildBr(g->builder, body_block);
 
+        if (continue_expr_node) {
+            LLVMPositionBuilderAtEnd(g->builder, continue_block);
+
+            gen_expr(g, continue_expr_node);
+
+            add_debug_source_node(g, node);
+            LLVMBuildBr(g->builder, body_block);
+        }
+
         LLVMPositionBuilderAtEnd(g->builder, body_block);
         g->break_block_stack.append(end_block);
-        g->continue_block_stack.append(body_block);
+        g->continue_block_stack.append(continue_block);
         gen_expr(g, node->data.while_expr.body);
         g->break_block_stack.pop();
         g->continue_block_stack.pop();
 
         if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
             add_debug_source_node(g, node);
-            LLVMBuildBr(g->builder, body_block);
+            LLVMBuildBr(g->builder, continue_block);
         }
 
         if (contains_break) {
@@ -2316,11 +2329,22 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
 
         LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileCond");
         LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
+        LLVMBasicBlockRef continue_block = continue_expr_node ?
+            LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileContinue") : cond_block;
         LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
 
         add_debug_source_node(g, node);
         LLVMBuildBr(g->builder, cond_block);
 
+        if (continue_expr_node) {
+            LLVMPositionBuilderAtEnd(g->builder, continue_block);
+
+            gen_expr(g, continue_expr_node);
+
+            add_debug_source_node(g, node);
+            LLVMBuildBr(g->builder, cond_block);
+        }
+
         LLVMPositionBuilderAtEnd(g->builder, cond_block);
         LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition);
         add_debug_source_node(g, node->data.while_expr.condition);
@@ -2328,13 +2352,13 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
 
         LLVMPositionBuilderAtEnd(g->builder, body_block);
         g->break_block_stack.append(end_block);
-        g->continue_block_stack.append(cond_block);
+        g->continue_block_stack.append(continue_block);
         gen_expr(g, node->data.while_expr.body);
         g->break_block_stack.pop();
         g->continue_block_stack.pop();
         if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
             add_debug_source_node(g, node);
-            LLVMBuildBr(g->builder, cond_block);
+            LLVMBuildBr(g->builder, continue_block);
         }
 
         LLVMPositionBuilderAtEnd(g->builder, end_block);
src/eval.cpp
@@ -1014,6 +1014,7 @@ static bool eval_while_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
 
     AstNode *cond_node = node->data.while_expr.condition;
     AstNode *body_node = node->data.while_expr.body;
+    AstNode *continue_expr_node = node->data.while_expr.continue_expr;
 
     EvalScope *my_scope = allocate<EvalScope>(1);
     my_scope->block_context = body_node->block_context;
@@ -1030,6 +1031,11 @@ static bool eval_while_expr(EvalFn *ef, AstNode *node, ConstExprValue *out_val)
         ConstExprValue body_val = {0};
         if (eval_expr(ef, body_node, &body_val)) return true;
 
+        if (continue_expr_node) {
+            ConstExprValue continue_expr_val = {0};
+            if (eval_expr(ef, continue_expr_node, &continue_expr_val)) return true;
+        }
+
         ef->root->branches_used += 1;
     }
 
src/parser.cpp
@@ -1886,7 +1886,7 @@ static AstNode *ast_parse_bool_or_expr(ParseContext *pc, int *token_index, bool
 }
 
 /*
-WhileExpression : token(While) token(LParen) Expression token(RParen) Expression
+WhileExpression = "while" "(" Expression option(";" Expression) ")" Expression
 */
 static AstNode *ast_parse_while_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -1904,9 +1904,20 @@ static AstNode *ast_parse_while_expr(ParseContext *pc, int *token_index, bool ma
 
     ast_eat_token(pc, token_index, TokenIdLParen);
     node->data.while_expr.condition = ast_parse_expression(pc, token_index, true);
-    ast_eat_token(pc, token_index, TokenIdRParen);
 
-    node->data.while_expr.body = ast_parse_expression(pc, token_index, true);
+    Token *semi_or_rparen = &pc->tokens->at(*token_index);
+
+    if (semi_or_rparen->id == TokenIdRParen) {
+        *token_index += 1;
+        node->data.while_expr.body = ast_parse_expression(pc, token_index, true);
+    } else if (semi_or_rparen->id == TokenIdSemicolon) {
+        *token_index += 1;
+        node->data.while_expr.continue_expr = ast_parse_expression(pc, token_index, true);
+        ast_eat_token(pc, token_index, TokenIdRParen);
+        node->data.while_expr.body = ast_parse_expression(pc, token_index, true);
+    } else {
+        ast_invalid_token_error(pc, semi_or_rparen);
+    }
 
 
     normalize_parent_ptrs(node);
std/rand.zig
@@ -107,12 +107,9 @@ pub struct Rand {
 fn test_float32() {
     var r = Rand.init(42);
 
-    // TODO for loop with range
-    var i: i32 = 0;
-    while (i < 1000) {
+    {var i: i32 = 0; while (i < 1000; i += 1) {
         const val = r.float32();
         if (!(val >= 0.0)) unreachable{};
         if (!(val < 1.0)) unreachable{};
-        i += 1;
-    }
+    }}
 }
std/str.zig
@@ -2,9 +2,7 @@ const assert = @import("index.zig").assert;
 
 pub fn len(ptr: &const u8) -> isize {
     var count: isize = 0;
-    while (ptr[count] != 0) {
-        count += 1;
-    }
+    while (ptr[count] != 0; count += 1) {}
     return count;
 }
 
test/self_hosted.zig
@@ -1258,3 +1258,14 @@ fn pub_enum_test(foo: other.APubEnum) {
 fn cast_with_imported_symbol() {
     assert(other.size_t(42) == 42);
 }
+
+
+#attribute("test")
+fn while_with_continue_expr() {
+    var sum: i32 = 0;
+    {var i: i32 = 0; while (i < 10; i += 1) {
+        if (i == 5) continue;
+        sum += i;
+    }}
+    assert(sum == 40);
+}