Commit 20b1491e6b

Andrew Kelley <superjoe30@gmail.com>
2017-05-04 16:18:01
implement while for nullables and error unions
See #357
1 parent 698829b
src/ast_render.cpp
@@ -725,12 +725,23 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
                 const char *inline_str = node->data.while_expr.is_inline ? "inline " : "";
                 fprintf(ar->f, "%swhile (", inline_str);
                 render_node_grouped(ar, node->data.while_expr.condition);
+                fprintf(ar->f, ") ");
+                if (node->data.while_expr.var_symbol) {
+                    fprintf(ar->f, "|%s| ", buf_ptr(node->data.while_expr.var_symbol));
+                }
                 if (node->data.while_expr.continue_expr) {
-                    fprintf(ar->f, "; ");
+                    fprintf(ar->f, ": (");
                     render_node_grouped(ar, node->data.while_expr.continue_expr);
+                    fprintf(ar->f, ") ");
                 }
-                fprintf(ar->f, ") ");
                 render_node_grouped(ar, node->data.while_expr.body);
+                if (node->data.while_expr.else_node) {
+                    fprintf(ar->f, " else ");
+                    if (node->data.while_expr.err_symbol) {
+                        fprintf(ar->f, "|%s| ", buf_ptr(node->data.while_expr.err_symbol));
+                    }
+                    render_node_grouped(ar, node->data.while_expr.else_node);
+                }
                 break;
             }
         case NodeTypeThisLiteral:
src/ir.cpp
@@ -4658,51 +4658,197 @@ static IrInstruction *ir_gen_while_expr(IrBuilder *irb, Scope *scope, AstNode *n
     assert(node->type == NodeTypeWhileExpr);
 
     AstNode *continue_expr_node = node->data.while_expr.continue_expr;
+    AstNode *else_node = node->data.while_expr.else_node;
 
     IrBasicBlock *cond_block = ir_build_basic_block(irb, scope, "WhileCond");
     IrBasicBlock *body_block = ir_build_basic_block(irb, scope, "WhileBody");
     IrBasicBlock *continue_block = continue_expr_node ?
         ir_build_basic_block(irb, scope, "WhileContinue") : cond_block;
     IrBasicBlock *end_block = ir_build_basic_block(irb, scope, "WhileEnd");
+    IrBasicBlock *else_block = else_node ?
+        ir_build_basic_block(irb, scope, "WhileElse") : end_block;
 
     IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node,
         ir_should_inline(irb->exec, scope) || node->data.while_expr.is_inline);
     ir_build_br(irb, scope, node, cond_block, is_comptime);
 
-    if (continue_expr_node) {
-        ir_set_cursor_at_end(irb, continue_block);
-        IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, scope);
-        if (expr_result == irb->codegen->invalid_instruction)
-            return expr_result;
-        if (!instr_is_unreachable(expr_result))
-            ir_mark_gen(ir_build_br(irb, scope, node, cond_block, is_comptime));
-    }
+    Buf *var_symbol = node->data.while_expr.var_symbol;
+    Buf *err_symbol = node->data.while_expr.err_symbol;
+    if (err_symbol != nullptr) {
+        ir_set_cursor_at_end(irb, cond_block);
+
+        Scope *payload_scope;
+        AstNode *symbol_node = node; // TODO make more accurate
+        VariableTableEntry *payload_var;
+        if (var_symbol) {
+            // TODO make it an error to write to payload variable
+            payload_var = ir_create_var(irb, symbol_node, scope, var_symbol,
+                    true, false, false, is_comptime);
+            payload_scope = payload_var->child_scope;
+        } else {
+            payload_scope = scope;
+        }
+        IrInstruction *err_val_ptr = ir_gen_node_extra(irb, node->data.while_expr.condition, scope, LVAL_PTR);
+        if (err_val_ptr == irb->codegen->invalid_instruction)
+            return err_val_ptr;
+        IrInstruction *err_val = ir_build_load_ptr(irb, scope, node->data.while_expr.condition, err_val_ptr);
+        IrInstruction *is_err = ir_build_test_err(irb, scope, node->data.while_expr.condition, err_val);
+        if (!instr_is_unreachable(is_err)) {
+            ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, is_err,
+                        else_block, body_block, is_comptime));
+        }
+
+        ir_set_cursor_at_end(irb, body_block);
+        if (var_symbol) {
+            IrInstruction *var_ptr_value = ir_build_unwrap_err_payload(irb, payload_scope, symbol_node,
+                    err_val_ptr, false);
+            IrInstruction *var_value = node->data.while_expr.var_is_ptr ?
+                var_ptr_value : ir_build_load_ptr(irb, payload_scope, symbol_node, var_ptr_value);
+            ir_build_var_decl(irb, payload_scope, symbol_node, payload_var, nullptr, var_value);
+        }
+        LoopStackItem *loop_stack_item = irb->loop_stack.add_one();
+        loop_stack_item->break_block = end_block;
+        loop_stack_item->continue_block = continue_block;
+        loop_stack_item->is_comptime = is_comptime;
+        IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, payload_scope);
+        if (body_result == irb->codegen->invalid_instruction)
+            return body_result;
+        irb->loop_stack.pop();
+
+        if (!instr_is_unreachable(body_result))
+            ir_mark_gen(ir_build_br(irb, payload_scope, node, continue_block, is_comptime));
+
+        if (continue_expr_node) {
+            ir_set_cursor_at_end(irb, continue_block);
+            IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, payload_scope);
+            if (expr_result == irb->codegen->invalid_instruction)
+                return expr_result;
+            if (!instr_is_unreachable(expr_result))
+                ir_mark_gen(ir_build_br(irb, payload_scope, node, cond_block, is_comptime));
+        }
+
+        if (else_node) {
+            ir_set_cursor_at_end(irb, else_block);
 
-    ir_set_cursor_at_end(irb, cond_block);
-    IrInstruction *cond_val = ir_gen_node(irb, node->data.while_expr.condition, scope);
-    if (cond_val == irb->codegen->invalid_instruction)
-        return cond_val;
-    if (!instr_is_unreachable(cond_val)) {
-        ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, cond_val,
-                    body_block, end_block, is_comptime));
-    }
+            // TODO make it an error to write to error variable
+            AstNode *err_symbol_node = else_node; // TODO make more accurate
+            VariableTableEntry *err_var = ir_create_var(irb, err_symbol_node, scope, err_symbol,
+                    true, false, false, is_comptime);
+            Scope *err_scope = err_var->child_scope;
+            IrInstruction *err_var_value = ir_build_unwrap_err_code(irb, err_scope, err_symbol_node, err_val_ptr);
+            ir_build_var_decl(irb, err_scope, symbol_node, err_var, nullptr, err_var_value);
+
+            IrInstruction *else_result = ir_gen_node(irb, else_node, err_scope);
+            if (else_result == irb->codegen->invalid_instruction)
+                return else_result;
+            if (!instr_is_unreachable(else_result))
+                ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime));
+        }
+
+        ir_set_cursor_at_end(irb, end_block);
+        return ir_build_const_void(irb, scope, node);
+    } else if (var_symbol != nullptr) {
+        ir_set_cursor_at_end(irb, cond_block);
+        // TODO make it an error to write to payload variable
+        AstNode *symbol_node = node; // TODO make more accurate
+        VariableTableEntry *payload_var = ir_create_var(irb, symbol_node, scope, var_symbol,
+                true, false, false, is_comptime);
+        Scope *child_scope = payload_var->child_scope;
+        IrInstruction *maybe_val_ptr = ir_gen_node_extra(irb, node->data.while_expr.condition, scope, LVAL_PTR);
+        if (maybe_val_ptr == irb->codegen->invalid_instruction)
+            return maybe_val_ptr;
+        IrInstruction *maybe_val = ir_build_load_ptr(irb, scope, node->data.while_expr.condition, maybe_val_ptr);
+        IrInstruction *is_non_null = ir_build_test_nonnull(irb, scope, node->data.while_expr.condition, maybe_val);
+        if (!instr_is_unreachable(is_non_null)) {
+            ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, is_non_null,
+                        body_block, else_block, is_comptime));
+        }
+
+        ir_set_cursor_at_end(irb, body_block);
+        IrInstruction *var_ptr_value = ir_build_unwrap_maybe(irb, child_scope, symbol_node, maybe_val_ptr, false);
+        IrInstruction *var_value = node->data.while_expr.var_is_ptr ?
+            var_ptr_value : ir_build_load_ptr(irb, child_scope, symbol_node, var_ptr_value);
+        ir_build_var_decl(irb, child_scope, symbol_node, payload_var, nullptr, var_value);
+        LoopStackItem *loop_stack_item = irb->loop_stack.add_one();
+        loop_stack_item->break_block = end_block;
+        loop_stack_item->continue_block = continue_block;
+        loop_stack_item->is_comptime = is_comptime;
+        IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, child_scope);
+        if (body_result == irb->codegen->invalid_instruction)
+            return body_result;
+        irb->loop_stack.pop();
+
+        if (!instr_is_unreachable(body_result))
+            ir_mark_gen(ir_build_br(irb, child_scope, node, continue_block, is_comptime));
+
+        if (continue_expr_node) {
+            ir_set_cursor_at_end(irb, continue_block);
+            IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, child_scope);
+            if (expr_result == irb->codegen->invalid_instruction)
+                return expr_result;
+            if (!instr_is_unreachable(expr_result))
+                ir_mark_gen(ir_build_br(irb, child_scope, node, cond_block, is_comptime));
+        }
+
+        if (else_node) {
+            ir_set_cursor_at_end(irb, else_block);
 
-    ir_set_cursor_at_end(irb, body_block);
+            IrInstruction *else_result = ir_gen_node(irb, else_node, scope);
+            if (else_result == irb->codegen->invalid_instruction)
+                return else_result;
+            if (!instr_is_unreachable(else_result))
+                ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime));
+        }
 
-    LoopStackItem *loop_stack_item = irb->loop_stack.add_one();
-    loop_stack_item->break_block = end_block;
-    loop_stack_item->continue_block = continue_block;
-    loop_stack_item->is_comptime = is_comptime;
-    IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, scope);
-    if (body_result == irb->codegen->invalid_instruction)
-        return body_result;
-    irb->loop_stack.pop();
+        ir_set_cursor_at_end(irb, end_block);
+        return ir_build_const_void(irb, scope, node);
+    } else {
+        if (continue_expr_node) {
+            ir_set_cursor_at_end(irb, continue_block);
+            IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, scope);
+            if (expr_result == irb->codegen->invalid_instruction)
+                return expr_result;
+            if (!instr_is_unreachable(expr_result))
+                ir_mark_gen(ir_build_br(irb, scope, node, cond_block, is_comptime));
+        }
+
+        ir_set_cursor_at_end(irb, cond_block);
+        IrInstruction *cond_val = ir_gen_node(irb, node->data.while_expr.condition, scope);
+        if (cond_val == irb->codegen->invalid_instruction)
+            return cond_val;
+        if (!instr_is_unreachable(cond_val)) {
+            ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, cond_val,
+                        body_block, else_block, is_comptime));
+        }
+
+        ir_set_cursor_at_end(irb, body_block);
+
+        LoopStackItem *loop_stack_item = irb->loop_stack.add_one();
+        loop_stack_item->break_block = end_block;
+        loop_stack_item->continue_block = continue_block;
+        loop_stack_item->is_comptime = is_comptime;
+        IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, scope);
+        if (body_result == irb->codegen->invalid_instruction)
+            return body_result;
+        irb->loop_stack.pop();
+
+        if (!instr_is_unreachable(body_result))
+            ir_mark_gen(ir_build_br(irb, scope, node, continue_block, is_comptime));
+
+        if (else_node) {
+            ir_set_cursor_at_end(irb, else_block);
 
-    if (!instr_is_unreachable(body_result))
-        ir_mark_gen(ir_build_br(irb, scope, node, continue_block, is_comptime));
-    ir_set_cursor_at_end(irb, end_block);
+            IrInstruction *else_result = ir_gen_node(irb, else_node, scope);
+            if (else_result == irb->codegen->invalid_instruction)
+                return else_result;
+            if (!instr_is_unreachable(else_result))
+                ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime));
+        }
 
-    return ir_build_const_void(irb, scope, node);
+        ir_set_cursor_at_end(irb, end_block);
+
+        return ir_build_const_void(irb, scope, node);
+    }
 }
 
 static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNode *node) {
src/parser.cpp
@@ -1654,7 +1654,7 @@ static AstNode *ast_parse_while_expr(ParseContext *pc, size_t *token_index, bool
             ast_eat_token(pc, token_index, TokenIdBinOr);
         }
 
-        node->data.while_expr.body = ast_parse_block_or_expression(pc, token_index, true);
+        node->data.while_expr.else_node = ast_parse_block_or_expression(pc, token_index, true);
     }
 
     return node;
std/linked_list.zig
@@ -187,43 +187,6 @@ pub fn LinkedList(comptime T: type) -> type {
             };
             return node;
         }
-
-        /// Iterate through the elements of the list.
-        ///
-        /// Returns:
-        ///     A list iterator with a next() method.
-        pub fn iterate(list: &List) -> List.Iterator(false) {
-            List.Iterator(false) {
-                .node = list.first,
-            }
-        }
-
-        /// Iterate through the elements of the list backwards.
-        ///
-        /// Returns:
-        ///     A list iterator with a next() method.
-        pub fn iterateBackwards(list: &List) -> List.Iterator(true) {
-            List.Iterator(true) {
-                .node = list.last,
-            }
-        }
-
-        /// Abstract iteration over a linked list.
-        pub fn Iterator(comptime backwards: bool) -> type {
-            struct {
-                const It = this;
-
-                node: ?&Node,
-
-                /// Return the next element of the list, until the end.
-                /// When no more elements are available, return null.
-                pub fn next(it: &It) -> ?&Node {
-                    const current = it.node ?? return null;
-                    it.node = if (backwards) current.prev else current.next;
-                    return current;
-                }
-            }
-        }
     }
 }
 
@@ -249,16 +212,24 @@ test "basic linked list test" {
     list.insertBefore(five, four);  // {1, 2, 4, 5}
     list.insertAfter(two, three);   // {1, 2, 3, 4, 5}
 
-    // Traverse the list forwards and backwards.
-    var it = list.iterate();
-    var it_reverse = list.iterateBackwards();
-    var index: u32 = 1;
-    while (true) {
-        const node = it.next() ?? break;
-        const node_reverse = it_reverse.next() ?? break;
-        assert (node.data == index);
-        assert (node_reverse.data == (6 - index));
-        index += 1;
+    // traverse forwards
+    {
+        var it = list.first;
+        var index: u32 = 1;
+        while (it) |node| : (it = node.next) {
+            assert(node.data == index);
+            index += 1;
+        }
+    }
+
+    // traverse backwards
+    {
+        var it = list.last;
+        var index: u32 = 1;
+        while (it) |node| : (it = node.prev) {
+            assert(node.data == (6 - index));
+            index += 1;
+        }
     }
 
     var first = list.popFirst();    // {2, 3, 4, 5}
test/cases/while.zig
@@ -1,6 +1,6 @@
 const assert = @import("std").debug.assert;
 
-test "whileLoop" {
+test "while loop" {
     var i : i32 = 0;
     while (i < 4) {
         i += 1;
@@ -16,7 +16,7 @@ fn whileLoop2() -> i32 {
         return 1;
     }
 }
-test "staticEvalWhile" {
+test "static eval while" {
     assert(static_eval_while_number == 1);
 }
 const static_eval_while_number = staticWhileLoop1();
@@ -29,7 +29,7 @@ fn staticWhileLoop2() -> i32 {
     }
 }
 
-test "continueAndBreak" {
+test "continue and break" {
     runContinueAndBreakTest();
     assert(continue_and_break_counter == 8);
 }
@@ -47,7 +47,7 @@ fn runContinueAndBreakTest() {
     assert(i == 4);
 }
 
-test "returnWithImplicitCastFromWhileLoop" {
+test "return with implicit cast from while loop" {
     %%returnWithImplicitCastFromWhileLoopTest();
 }
 fn returnWithImplicitCastFromWhileLoopTest() -> %void {
@@ -56,7 +56,7 @@ fn returnWithImplicitCastFromWhileLoopTest() -> %void {
     }
 }
 
-test "whileWithContinueExpr" {
+test "while with continue expression" {
     var sum: i32 = 0;
     {var i: i32 = 0; while (i < 10) : (i += 1) {
         if (i == 5) continue;
@@ -64,3 +64,73 @@ test "whileWithContinueExpr" {
     }}
     assert(sum == 40);
 }
+
+test "while with else" {
+    var sum: i32 = 0;
+    var i: i32 = 0;
+    var got_else: i32 = 0;
+    while (i < 10) : (i += 1) {
+        sum += 1;
+    } else {
+        got_else += 1;
+    }
+    assert(sum == 10);
+    assert(got_else == 1);
+}
+
+test "while with nullable as condition" {
+    numbers_left = 10;
+    var sum: i32 = 0;
+    while (getNumberOrNull()) |value| {
+        sum += value;
+    }
+    assert(sum == 45);
+}
+
+test "while with nullable as condition with else" {
+    numbers_left = 10;
+    var sum: i32 = 0;
+    var got_else: i32 = 0;
+    while (getNumberOrNull()) |value| {
+        sum += value;
+        assert(got_else == 0);
+    } else {
+        got_else += 1;
+    }
+    assert(sum == 45);
+    assert(got_else == 1);
+}
+
+test "while with error union condition" {
+    numbers_left = 10;
+    var sum: i32 = 0;
+    var got_else: i32 = 0;
+    while (getNumberOrErr()) |value| {
+        sum += value;
+    } else |err| {
+        assert(err == error.OutOfNumbers);
+        got_else += 1;
+    }
+    assert(sum == 45);
+    assert(got_else == 1);
+}
+
+var numbers_left: i32 = undefined;
+error OutOfNumbers;
+fn getNumberOrErr() -> %i32 {
+    return if (numbers_left == 0) {
+        error.OutOfNumbers
+    } else {
+        numbers_left -= 1;
+        numbers_left
+    };
+}
+fn getNumberOrNull() -> ?i32 {
+    return if (numbers_left == 0) {
+        null
+    } else {
+        numbers_left -= 1;
+        numbers_left
+    };
+}
+
test/compile_errors.zig
@@ -1636,4 +1636,52 @@ pub fn addCases(cases: &tests.CompileErrorContext) {
     ,
         ".tmp_source.zig:9:17: error: redefinition of 'Self'",
         ".tmp_source.zig:5:9: note: previous definition is here");
+
+    cases.add("while expected bool, got nullable",
+        \\export fn foo() {
+        \\    while (bar()) {}
+        \\}
+        \\fn bar() -> ?i32 { 1 }
+    ,
+        ".tmp_source.zig:2:15: error: expected type 'bool', found '?i32'");
+
+    cases.add("while expected bool, got error union",
+        \\export fn foo() {
+        \\    while (bar()) {}
+        \\}
+        \\fn bar() -> %i32 { 1 }
+    ,
+        ".tmp_source.zig:2:15: error: expected type 'bool', found '%i32'");
+
+    cases.add("while expected nullable, got bool",
+        \\export fn foo() {
+        \\    while (bar()) |x| {}
+        \\}
+        \\fn bar() -> bool { true }
+    ,
+        ".tmp_source.zig:2:15: error: expected nullable type, found 'bool'");
+
+    cases.add("while expected nullable, got error union",
+        \\export fn foo() {
+        \\    while (bar()) |x| {}
+        \\}
+        \\fn bar() -> %i32 { 1 }
+    ,
+        ".tmp_source.zig:2:15: error: expected nullable type, found '%i32'");
+
+    cases.add("while expected error union, got bool",
+        \\export fn foo() {
+        \\    while (bar()) |x| {} else |err| {}
+        \\}
+        \\fn bar() -> bool { true }
+    ,
+        ".tmp_source.zig:2:15: error: expected error union type, found 'bool'");
+
+    cases.add("while expected error union, got nullable",
+        \\export fn foo() {
+        \\    while (bar()) |x| {} else |err| {}
+        \\}
+        \\fn bar() -> ?i32 { 1 }
+    ,
+        ".tmp_source.zig:2:15: error: expected error union type, found '?i32'");
 }