Commit faa44e2e58

Andrew Kelley <andrew@ziglang.org>
2023-02-17 19:51:22
AstGen: rework multi-object for loop
* Allow unbounded looping.
* Lower by incrementing raw pointers for each iterable rather than
  incrementing a single index variable. This elides safety checks without
  any analysis required, thanks to the length assertion, and lowers to
  decent machine code even in debug builds.
  - An "end" value is selected, prioritizing a counter if possible and
    falling back to a runtime calculation of ptr+len on a slice input.
* Specialize on the pattern `0..`, avoiding an unnecessary subtraction
  instruction being emitted.
* Add the `for_check_lens` ZIR instruction.
1 parent 6733e43
src/AstGen.zig
@@ -2666,6 +2666,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .validate_deref,
             .save_err_ret_index,
             .restore_err_ret_index,
+            .for_check_lens,
             => break :b true,
 
             .@"defer" => unreachable,
@@ -6294,37 +6295,35 @@ fn forExpr(
         try astgen.checkLabelRedefinition(scope, label_token);
     }
 
-    // Set up variables and constants.
     const is_inline = parent_gz.force_comptime or for_full.inline_token != null;
     const tree = astgen.tree;
     const token_tags = tree.tokens.items(.tag);
     const node_tags = tree.nodes.items(.tag);
     const node_data = tree.nodes.items(.data);
+    const gpa = astgen.gpa;
 
-    // Check for unterminated ranges.
-    {
-        var unterminated: ?Ast.Node.Index = null;
-        for (for_full.ast.inputs) |input| {
-            if (node_tags[input] != .for_range) break;
-            if (node_data[input].rhs != 0) break;
-            unterminated = unterminated orelse input;
-        } else {
-            return astgen.failNode(unterminated.?, "unterminated for range", .{});
-        }
-    }
-
-    var lens = astgen.gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
-    defer astgen.gpa.free(lens);
-    var indexables = astgen.gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
-    defer astgen.gpa.free(indexables);
-    var counters = std.ArrayList(Zir.Inst.Ref).init(astgen.gpa);
-    defer counters.deinit();
+    const allocs = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
+    defer gpa.free(allocs);
+    // elements of this array can be `none`, indicating no length check.
+    const lens = try gpa.alloc(Zir.Inst.Ref, for_full.ast.inputs.len);
+    defer gpa.free(lens);
 
     const counter_alloc_tag: Zir.Inst.Tag = if (is_inline) .alloc_comptime_mut else .alloc;
 
+    // Tracks the index of allocs/lens that has a length to be checked and is
+    // used for the end value.
+    // If this is null, there are no len checks.
+    var end_input_index: ?u32 = null;
+    // This is a value to use to find out if the for loop has reached the end
+    // yet. It prefers to use a counter since the end value is provided directly,
+    // and otherwise falls back to adding ptr+len of a slice to compute end.
+    // Corresponds to end_input_index and will be .none in case that value is null.
+    var cond_end_val: Zir.Inst.Ref = .none;
+
     {
         var payload = for_full.payload_token;
-        for (for_full.ast.inputs) |input, i| {
+        for (for_full.ast.inputs) |input, i_usize| {
+            const i = @intCast(u32, i_usize);
             const payload_is_ref = token_tags[payload] == .asterisk;
             const ident_tok = payload + @boolToInt(payload_is_ref);
 
@@ -6339,59 +6338,101 @@ fn forExpr(
                     return astgen.failTok(ident_tok, "cannot capture reference to range", .{});
                 }
                 const counter_ptr = try parent_gz.addUnNode(counter_alloc_tag, .usize_type, node);
-                const start_val = try expr(parent_gz, scope, node_data[input].lhs, input);
+                const start_node = node_data[input].lhs;
+                const start_val = try expr(parent_gz, scope, .{ .rl = .none }, start_node);
                 _ = try parent_gz.addBin(.store, counter_ptr, start_val);
-                indexables[i] = counter_ptr;
-                try counters.append(counter_ptr);
 
                 const end_node = node_data[input].rhs;
-                const end_val = if (end_node != 0) try expr(parent_gz, scope, node_data[input].rhs, input) else .none;
-                const range_len = try parent_gz.addPlNode(.for_range_len, input, Zir.Inst.Bin{
-                    .lhs = start_val,
-                    .rhs = end_val,
-                });
+                const end_val = if (end_node != 0)
+                    try expr(parent_gz, scope, .{ .rl = .none }, node_data[input].rhs)
+                else
+                    .none;
+
+                const range_len = if (end_val == .none or nodeIsTriviallyZero(tree, start_node))
+                    end_val
+                else
+                    try parent_gz.addPlNode(.sub, input, Zir.Inst.Bin{
+                        .lhs = end_val,
+                        .rhs = start_val,
+                    });
+
+                if (range_len != .none and cond_end_val == .none) {
+                    end_input_index = i;
+                    cond_end_val = end_val;
+                }
+
+                allocs[i] = counter_ptr;
                 lens[i] = range_len;
             } else {
                 const cond_ri: ResultInfo = .{ .rl = if (payload_is_ref) .ref else .none };
                 const indexable = try expr(parent_gz, scope, cond_ri, input);
-                indexables[i] = indexable;
+                const base_ptr = try parent_gz.addPlNode(.elem_ptr_imm, input, Zir.Inst.ElemPtrImm{
+                    .ptr = indexable,
+                    .index = 0,
+                });
 
-                const indexable_len = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input);
-                lens[i] = indexable_len;
+                if (end_input_index == null) {
+                    end_input_index = i;
+                    assert(cond_end_val == .none);
+                }
+
+                allocs[i] = base_ptr;
+                lens[i] = try parent_gz.addUnNode(.indexable_ptr_len, indexable, input);
             }
         }
     }
 
-    const len = "check_for_lens";
+    // In case there are no counters which already have an end computed, we
+    // compute an end from base pointer plus length.
+    if (end_input_index) |i| {
+        if (cond_end_val == .none) {
+            cond_end_val = try parent_gz.addPlNode(.add, for_full.ast.inputs[i], Zir.Inst.Bin{
+                .lhs = allocs[i],
+                .rhs = lens[i],
+            });
+        }
+    }
 
-    const index_ptr = blk: {
-        // Future optimization:
-        // for loops with only ranges don't need a separate index variable.
-        const index_ptr = try parent_gz.addUnNode(counter_alloc_tag, .usize_type, node);
-        // initialize to zero
-        _ = try parent_gz.addBin(.store, index_ptr, .zero_usize);
-        try counters.append(index_ptr);
-        break :blk index_ptr;
-    };
+    // We use a dedicated ZIR instruction to assert the lengths to assist with
+    // nicer error reporting as well as fewer ZIR bytes emitted.
+    if (end_input_index != null) {
+        const lens_len = @intCast(u32, lens.len);
+        try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.MultiOp).Struct.fields.len + lens_len);
+        _ = try parent_gz.addPlNode(.for_check_lens, node, Zir.Inst.MultiOp{
+            .operands_len = lens_len,
+        });
+        appendRefsAssumeCapacity(astgen, lens);
+    }
 
     const loop_tag: Zir.Inst.Tag = if (is_inline) .block_inline else .loop;
     const loop_block = try parent_gz.makeBlockInst(loop_tag, node);
-    try parent_gz.instructions.append(astgen.gpa, loop_block);
+    try parent_gz.instructions.append(gpa, loop_block);
 
     var loop_scope = parent_gz.makeSubBlock(scope);
     loop_scope.is_inline = is_inline;
     loop_scope.setBreakResultInfo(ri);
     defer loop_scope.unstack();
-    defer loop_scope.labeled_breaks.deinit(astgen.gpa);
+    defer loop_scope.labeled_breaks.deinit(gpa);
 
     var cond_scope = parent_gz.makeSubBlock(&loop_scope.base);
     defer cond_scope.unstack();
 
-    // check condition i < array_expr.len
-    const index = try cond_scope.addUnNode(.load, index_ptr, for_full.ast.cond_expr);
-    const cond = try cond_scope.addPlNode(.cmp_lt, for_full.ast.cond_expr, Zir.Inst.Bin{
-        .lhs = index,
-        .rhs = len,
+    // Load all the iterables.
+    const loaded_ptrs = try gpa.alloc(Zir.Inst.Ref, allocs.len);
+    defer gpa.free(loaded_ptrs);
+    for (allocs) |alloc, i| {
+        loaded_ptrs[i] = try cond_scope.addUnNode(.load, alloc, for_full.ast.inputs[i]);
+    }
+
+    // Check the condition.
+    const input_index = end_input_index orelse {
+        return astgen.failNode(node, "TODO: handle infinite for loop", .{});
+    };
+    assert(cond_end_val != .none);
+
+    const cond = try cond_scope.addPlNode(.cmp_neq, for_full.ast.inputs[input_index], Zir.Inst.Bin{
+        .lhs = loaded_ptrs[input_index],
+        .rhs = cond_end_val,
     });
 
     const condbr_tag: Zir.Inst.Tag = if (is_inline) .condbr_inline else .condbr;
@@ -6400,16 +6441,15 @@ fn forExpr(
     const cond_block = try loop_scope.makeBlockInst(block_tag, node);
     try cond_scope.setBlockBody(cond_block);
     // cond_block unstacked now, can add new instructions to loop_scope
-    try loop_scope.instructions.append(astgen.gpa, cond_block);
+    try loop_scope.instructions.append(gpa, cond_block);
 
-    // Increment the index variable and ranges.
-    for (counters) |counter_ptr| {
-        const counter = try loop_scope.addUnNode(.load, counter_ptr, for_full.ast.cond_expr);
-        const counter_plus_one = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{
-            .lhs = counter,
+    // Increment the loop variables.
+    for (allocs) |alloc, i| {
+        const incremented = try loop_scope.addPlNode(.add, node, Zir.Inst.Bin{
+            .lhs = loaded_ptrs[i],
             .rhs = .one_usize,
         });
-        _ = try loop_scope.addBin(.store, counter_ptr, counter_plus_one);
+        _ = try loop_scope.addBin(.store, alloc, incremented);
     }
     const repeat_tag: Zir.Inst.Tag = if (is_inline) .repeat_inline else .repeat;
     _ = try loop_scope.addNode(repeat_tag, node);
@@ -8960,6 +9000,25 @@ comptime {
     }
 }
 
+/// Reports whether `node` is a number literal that parses as the integer zero.
+/// Any other kind of node, and any literal that does not parse as an integer
+/// equal to zero, yields `false`.
+fn nodeIsTriviallyZero(tree: *const Ast, node: Ast.Node.Index) bool {
+    // Only number literals can qualify; everything else is rejected up front.
+    if (tree.nodes.items(.tag)[node] != .number_literal) return false;
+
+    // The main token of a number literal node is the literal's own token.
+    const literal_token = tree.nodes.items(.main_token)[node];
+    const parsed = std.zig.parseNumberLiteral(tree.tokenSlice(literal_token));
+    return switch (parsed) {
+        .int => |value| value == 0,
+        else => false,
+    };
+}
+
 fn nodeMayNeedMemoryLocation(tree: *const Ast, start_node: Ast.Node.Index, have_res_ty: bool) bool {
     const node_tags = tree.nodes.items(.tag);
     const node_datas = tree.nodes.items(.data);
src/print_zir.zig
@@ -355,6 +355,8 @@ const Writer = struct {
             .array_type,
             => try self.writePlNodeBin(stream, inst),
 
+            .for_check_lens => try self.writePlNodeMultiOp(stream, inst),
+
             .elem_ptr_imm => try self.writeElemPtrImm(stream, inst),
 
             .@"export" => try self.writePlNodeExport(stream, inst),
@@ -868,6 +870,19 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    /// Prints a `pl_node` instruction whose payload is a `Zir.Inst.MultiOp`:
+    /// the trailing operand refs as a brace-enclosed, comma-separated list,
+    /// followed by the instruction's source location.
+    fn writePlNodeMultiOp(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+        const pl_node = self.code.instructions.items(.data)[inst].pl_node;
+        const multi_op = self.code.extraData(Zir.Inst.MultiOp, pl_node.payload_index);
+        // The operand refs are stored directly after the `MultiOp` header.
+        const operands = self.code.refSlice(multi_op.end, multi_op.data.operands_len);
+
+        try stream.writeAll("{");
+        if (operands.len > 0) {
+            try self.writeInstRef(stream, operands[0]);
+            for (operands[1..]) |operand| {
+                try stream.writeAll(", ");
+                try self.writeInstRef(stream, operand);
+            }
+        }
+        // NOTE(review): the ")" presumably closes a "(" emitted by the caller
+        // before dispatching here; confirm against the dispatch site.
+        try stream.writeAll("}) ");
+        try self.writeSrc(stream, pl_node.src());
+    }
+
     fn writeElemPtrImm(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.ElemPtrImm, inst_data.payload_index).data;
src/Sema.zig
@@ -1386,6 +1386,11 @@ fn analyzeBodyInner(
                 i += 1;
                 continue;
             },
+            .for_check_lens => {
+                try sema.zirForCheckLens(block, inst);
+                i += 1;
+                continue;
+            },
 
             // Special case instructions to handle comptime control flow.
             .@"break" => {
@@ -17096,6 +17101,16 @@ fn zirRestoreErrRetIndex(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index)
     return sema.popErrorReturnTrace(start_block, src, operand, saved_index);
 }
 
+/// Semantic analysis entry point for the `for_check_lens` ZIR instruction.
+/// Currently a stub: it decodes the `MultiOp` payload and its trailing
+/// operand refs, then unconditionally reports a "TODO" compile error at the
+/// instruction's source location.
+fn zirForCheckLens(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void {
+    const pl_node = sema.code.instructions.items(.data)[inst].pl_node;
+    const multi_op = sema.code.extraData(Zir.Inst.MultiOp, pl_node.payload_index);
+
+    // The operand refs are decoded but not yet used by this stub.
+    _ = sema.code.refSlice(multi_op.end, multi_op.data.operands_len);
+
+    return sema.fail(block, pl_node.src(), "TODO implement zirForCheckLens", .{});
+}
+
 fn addToInferredErrorSet(sema: *Sema, uncasted_operand: Air.Inst.Ref) !void {
     assert(sema.fn_ret_ty.zigTypeTag() == .ErrorUnion);
 
src/Zir.zig
@@ -497,6 +497,15 @@ pub const Inst = struct {
         /// Sends comptime control flow back to the beginning of the current block.
         /// Uses the `node` field.
         repeat_inline,
+        /// Asserts that all the lengths provided match. Used to build a for loop.
+        /// Return value is always void.
+        /// Uses the `pl_node` field with payload `MultiOp`.
+        /// There is exactly one item corresponding to each AST node inside the for
+        /// loop condition. Each item may be `none`, indicating an unbounded range.
+        /// Illegal behaviors:
+        ///  * If all lengths are unbounded ranges (always a compile error).
+        ///  * If any two lengths do not match each other.
+        for_check_lens,
         /// Merge two error sets into one, `E1 || E2`.
         /// Uses the `pl_node` field with payload `Bin`.
         merge_error_sets,
@@ -1242,6 +1251,7 @@ pub const Inst = struct {
                 .defer_err_code,
                 .save_err_ret_index,
                 .restore_err_ret_index,
+                .for_check_lens,
                 => false,
 
                 .@"break",
@@ -1309,6 +1319,7 @@ pub const Inst = struct {
                 .memcpy,
                 .memset,
                 .check_comptime_control_flow,
+                .for_check_lens,
                 .@"defer",
                 .defer_err_code,
                 .restore_err_ret_index,
@@ -1588,6 +1599,7 @@ pub const Inst = struct {
                 .@"break" = .@"break",
                 .break_inline = .@"break",
                 .check_comptime_control_flow = .un_node,
+                .for_check_lens = .pl_node,
                 .call = .pl_node,
                 .cmp_lt = .pl_node,
                 .cmp_lte = .pl_node,