Commit 4618c41fa6

Andrew Kelley <andrew@ziglang.org>
2022-04-03 02:59:07
Sema: mechanism for converting comptime breaks to runtime
closes #11369
1 parent 83bb98e
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -68,6 +68,11 @@ preallocated_new_func: ?*Module.Fn = null,
 /// TODO: after upgrading to use InternPool change the key here to be an
 /// InternPool value index.
 types_to_resolve: std.ArrayListUnmanaged(Air.Inst.Ref) = .{},
+/// These are lazily created runtime blocks from block_inline instructions.
+/// They are created when a break_inline passes through a runtime condition, because
+/// Sema must convert comptime control flow to runtime control flow, which means
+/// breaking from a block. Keyed by the AIR index of the lazily created block.
+post_hoc_blocks: std.AutoHashMapUnmanaged(Air.Inst.Index, *LabeledBlock) = .{},
 
 const std = @import("std");
 const mem = std.mem;
@@ -525,6 +530,18 @@ pub const Block = struct {
     };
 };
 
+const LabeledBlock = struct { // A `Block` stored together with the `Block.Label` it refers to; instances are heap-allocated.
+    block: Block, // Collects the instructions that end up inside the post-hoc runtime block.
+    label: Block.Label, // Holds the merges (`results`, `br_list`) recorded for breaks to this block.
+
+    fn destroy(lb: *LabeledBlock, gpa: Allocator) void { // Frees the three owned lists, then the allocation itself.
+        lb.block.instructions.deinit(gpa);
+        lb.label.merges.results.deinit(gpa);
+        lb.label.merges.br_list.deinit(gpa);
+        gpa.destroy(lb); // `lb` must not be used after this call.
+    }
+};
+
 pub fn deinit(sema: *Sema) void {
     const gpa = sema.gpa;
     sema.air_instructions.deinit(gpa);
@@ -533,6 +550,14 @@ pub fn deinit(sema: *Sema) void {
     sema.inst_map.deinit(gpa);
     sema.decl_val_table.deinit(gpa);
     sema.types_to_resolve.deinit(gpa);
+    {
+        var it = sema.post_hoc_blocks.iterator();
+        while (it.next()) |entry| {
+            const labeled_block = entry.value_ptr.*;
+            labeled_block.destroy(gpa);
+        }
+        sema.post_hoc_blocks.deinit(gpa);
+    }
     sema.* = undefined;
 }
 
@@ -573,8 +598,8 @@ pub fn analyzeBody(
 
 const BreakData = struct {
     block_inst: Zir.Inst.Index,
-    operand: Air.Inst.Ref,
-    inst: Air.Inst.Index,
+    operand: Zir.Inst.Ref, // Unresolved ZIR operand of the break; users pass it to `sema.resolveInst`.
+    inst: Zir.Inst.Index, // ZIR index of the break instruction itself.
 };
 
 pub fn analyzeBodyBreak(
@@ -1192,20 +1217,67 @@ fn analyzeBodyInner(
             },
             .block_inline => blk: {
                 // Directly analyze the block body without introducing a new block.
+                // However, in the case of a corresponding break_inline which reaches
+                // through a runtime conditional branch, we must retroactively emit
+                // a block, so we remember the block index here just in case.
+                const block_index = block.instructions.items.len;
                 const inst_data = datas[inst].pl_node;
                 const extra = sema.code.extraData(Zir.Inst.Block, inst_data.payload_index);
                 const inline_body = sema.code.extra[extra.end..][0..extra.data.body_len];
+                const gpa = sema.gpa;
                 // If this block contains a function prototype, we need to reset the
                 // current list of parameters and restore it later.
                 // Note: this probably needs to be resolved in a more general manner.
                 const prev_params = block.params;
                 block.params = .{};
                 defer {
-                    block.params.deinit(sema.gpa);
+                    block.params.deinit(gpa);
                     block.params = prev_params;
                 }
-                const break_data = (try sema.analyzeBodyBreak(block, inline_body)) orelse
-                    break always_noreturn;
+                const opt_break_data = try sema.analyzeBodyBreak(block, inline_body);
+                // A runtime conditional branch that needs a post-hoc block to be
+                // emitted communicates this by mapping the block index into the inst map.
+                if (map.get(inst)) |new_block_ref| ph: {
+                    // Comptime control flow populates the map, so we don't actually know
+                    // if this is a post-hoc runtime block until we check the
+                    // `post_hoc_blocks` map.
+                    const new_block_inst = Air.refToIndex(new_block_ref) orelse break :ph;
+                    const labeled_block = sema.post_hoc_blocks.get(new_block_inst) orelse
+                        break :ph;
+
+                    // In this case we need to move all the instructions starting at
+                    // block_index from the current block into this new one.
+
+                    if (opt_break_data) |break_data| {
+                        // This is a comptime break which we now change to a runtime break
+                        // since it crosses a runtime branch.
+                        // It may pass through our currently being analyzed block_inline or it
+                        // may point directly to it. In the latter case, this modifies the
+                        // block that we are about to look up in the post_hoc_blocks map below.
+                        try sema.addRuntimeBreak(block, break_data);
+                    } else {
+                        // Here the comptime control flow ends with noreturn; however
+                        // we have runtime control flow continuing after this block.
+                        // This branch is therefore handled by the `i += 1; continue;`
+                        // logic below.
+                    }
+
+                    try labeled_block.block.instructions.appendSlice(gpa, block.instructions.items[block_index..]);
+                    block.instructions.items.len = block_index;
+
+                    const block_result = try sema.analyzeBlockBody(block, inst_data.src(), &labeled_block.block, &labeled_block.label.merges);
+                    {
+                        // Destroy the post-hoc block entry so that it does not interfere with
+                        // the next iteration of comptime control flow, if any.
+                        labeled_block.destroy(gpa);
+                        assert(sema.post_hoc_blocks.remove(new_block_inst));
+                    }
+                    try map.put(gpa, inst, block_result);
+                    i += 1;
+                    continue;
+                }
+
+                const break_data = opt_break_data orelse break always_noreturn;
                 if (inst == break_data.block_inst) {
                     break :blk sema.resolveInst(break_data.operand);
                 } else {
@@ -3996,13 +4068,12 @@ fn zirBlock(sema: *Sema, parent_block: *Block, inst: Zir.Inst.Index) CompileErro
         .inlining = parent_block.inlining,
         .is_comptime = parent_block.is_comptime,
     };
-    const merges = &child_block.label.?.merges;
 
     defer child_block.instructions.deinit(gpa);
-    defer merges.results.deinit(gpa);
-    defer merges.br_list.deinit(gpa);
+    defer label.merges.results.deinit(gpa);
+    defer label.merges.br_list.deinit(gpa);
 
-    return sema.resolveBlockBody(parent_block, src, &child_block, body, inst, merges);
+    return sema.resolveBlockBody(parent_block, src, &child_block, body, inst, &label.merges);
 }
 
 fn resolveBlockBody(
@@ -7955,7 +8026,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         const item = sema.resolveInst(item_ref);
         // `item` is already guaranteed to be constant known.
 
-        try sema.analyzeBody(&case_block, body);
+        _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+            error.ComptimeBreak => {
+                const zir_datas = sema.code.instructions.items(.data);
+                const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                try sema.addRuntimeBreak(&case_block, .{
+                    .block_inst = break_data.block_inst,
+                    .operand = break_data.operand,
+                    .inst = sema.comptime_break_inst,
+                });
+            },
+            else => |e| return e,
+        };
 
         try wip_captures.finalize();
 
@@ -7998,7 +8080,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
             const body = sema.code.extra[extra_index..][0..body_len];
             extra_index += body_len;
-            try sema.analyzeBody(&case_block, body);
+            _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                error.ComptimeBreak => {
+                    const zir_datas = sema.code.instructions.items(.data);
+                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                    try sema.addRuntimeBreak(&case_block, .{
+                        .block_inst = break_data.block_inst,
+                        .operand = break_data.operand,
+                        .inst = sema.comptime_break_inst,
+                    });
+                },
+                else => |e| return e,
+            };
 
             try cases_extra.ensureUnusedCapacity(gpa, 2 + items.len +
                 case_block.instructions.items.len);
@@ -8073,7 +8166,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
             const body = sema.code.extra[extra_index..][0..body_len];
             extra_index += body_len;
-            try sema.analyzeBody(&case_block, body);
+            _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                error.ComptimeBreak => {
+                    const zir_datas = sema.code.instructions.items(.data);
+                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                    try sema.addRuntimeBreak(&case_block, .{
+                        .block_inst = break_data.block_inst,
+                        .operand = break_data.operand,
+                        .inst = sema.comptime_break_inst,
+                    });
+                },
+                else => |e| return e,
+            };
 
             try wip_captures.finalize();
 
@@ -8110,7 +8214,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         case_block.wip_capture_scope = wip_captures.scope;
 
         if (special.body.len != 0) {
-            try sema.analyzeBody(&case_block, special.body);
+            _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
+                error.ComptimeBreak => {
+                    const zir_datas = sema.code.instructions.items(.data);
+                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                    try sema.addRuntimeBreak(&case_block, .{
+                        .block_inst = break_data.block_inst,
+                        .operand = break_data.operand,
+                        .inst = sema.comptime_break_inst,
+                    });
+                },
+                else => |e| return e,
+            };
         } else {
             // We still need a terminator in this block, but we have proven
             // that it is unreachable.
@@ -11979,11 +12094,33 @@ fn zirCondbr(
     sub_block.runtime_index += 1;
     defer sub_block.instructions.deinit(gpa);
 
-    try sema.analyzeBody(&sub_block, then_body);
+    _ = sema.analyzeBodyInner(&sub_block, then_body) catch |err| switch (err) {
+        error.ComptimeBreak => {
+            const zir_datas = sema.code.instructions.items(.data);
+            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+            try sema.addRuntimeBreak(&sub_block, .{
+                .block_inst = break_data.block_inst,
+                .operand = break_data.operand,
+                .inst = sema.comptime_break_inst,
+            });
+        },
+        else => |e| return e,
+    };
     const true_instructions = sub_block.instructions.toOwnedSlice(gpa);
     defer gpa.free(true_instructions);
 
-    try sema.analyzeBody(&sub_block, else_body);
+    _ = sema.analyzeBodyInner(&sub_block, else_body) catch |err| switch (err) {
+        error.ComptimeBreak => {
+            const zir_datas = sema.code.instructions.items(.data);
+            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+            try sema.addRuntimeBreak(&sub_block, .{
+                .block_inst = break_data.block_inst,
+                .operand = break_data.operand,
+                .inst = sema.comptime_break_inst,
+            });
+        },
+        else => |e| return e,
+    };
     try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.CondBr).Struct.fields.len +
         true_instructions.len + sub_block.instructions.items.len);
     _ = try parent_block.addInst(.{
@@ -12001,6 +12138,61 @@ fn zirCondbr(
     return always_noreturn;
 }
 
+/// Converts a comptime `break` that sits inside a runtime condition (e.g. breaking from an
+/// inline loop) into a runtime `br`. Looks up, or lazily creates, the post-hoc `LabeledBlock`
+/// for `break_data.block_inst`, adds the `br` to `child_block`, and records it in the merges.
+fn addRuntimeBreak(sema: *Sema, child_block: *Block, break_data: BreakData) !void {
+    const gop = try sema.inst_map.getOrPut(sema.gpa, break_data.block_inst);
+    const labeled_block = if (!gop.found_existing) blk: { // No mapping yet for this ZIR block: create the AIR block and its bookkeeping.
+        try sema.post_hoc_blocks.ensureUnusedCapacity(sema.gpa, 1); // Reserve up front so the `putAssumeCapacityNoClobber` below needs no allocation.
+
+        const new_block_inst = @intCast(Air.Inst.Index, sema.air_instructions.len); // Index the appended instruction will receive.
+        gop.value_ptr.* = Air.indexToRef(new_block_inst);
+        try sema.air_instructions.append(sema.gpa, .{
+            .tag = .block,
+            .data = undefined, // Not filled in here; presumably set when the block body is analyzed — TODO confirm.
+        });
+        const labeled_block = try sema.gpa.create(LabeledBlock);
+        labeled_block.* = .{
+            .label = .{
+                .zir_block = break_data.block_inst,
+                .merges = .{
+                    .results = .{},
+                    .br_list = .{},
+                    .block_inst = new_block_inst,
+                },
+            },
+            .block = .{
+                .parent = child_block,
+                .sema = sema,
+                .src_decl = child_block.src_decl,
+                .namespace = child_block.namespace,
+                .wip_capture_scope = child_block.wip_capture_scope,
+                .instructions = .{},
+                .label = &labeled_block.label, // Points at the sibling field inside the same heap allocation.
+                .inlining = child_block.inlining,
+                .is_comptime = child_block.is_comptime,
+            },
+        };
+        sema.post_hoc_blocks.putAssumeCapacityNoClobber(new_block_inst, labeled_block);
+        break :blk labeled_block;
+    } else blk: { // A mapping already exists; reuse the post-hoc block it refers to.
+        const new_block_inst = Air.refToIndex(gop.value_ptr.*).?; // NOTE(review): both `.?` here assume the existing entry is a post-hoc block,
+        const labeled_block = sema.post_hoc_blocks.get(new_block_inst).?; // not a mapping written by comptime control flow — confirm against callers.
+        break :blk labeled_block;
+    };
+
+    const operand = sema.resolveInst(break_data.operand); // Resolve the ZIR operand before emitting the `br`.
+    const br_ref = try child_block.addBr(labeled_block.label.merges.block_inst, operand);
+    try labeled_block.label.merges.results.append(sema.gpa, operand); // `results` and `br_list` each gain one entry per converted break.
+    try labeled_block.label.merges.br_list.append(sema.gpa, Air.refToIndex(br_ref).?);
+    labeled_block.block.runtime_index += 1; // Incremented on every call, i.e. once per converted break.
+    if (labeled_block.block.runtime_cond == null and labeled_block.block.runtime_loop == null) { // Only written while both are still unset.
+        labeled_block.block.runtime_cond = child_block.runtime_cond orelse child_block.runtime_loop;
+        labeled_block.block.runtime_loop = child_block.runtime_loop;
+    }
+}
+
 fn zirUnreachable(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index {
     const tracy = trace(@src());
     defer tracy.end();
test/behavior/eval.zig
@@ -893,3 +893,108 @@ test "closure capture type of runtime-known parameter" {
     var c: i32 = 1234;
     try S.b(c);
 }
+
+test "comptime break passing through runtime condition converted to runtime break" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest() !void {
+            var runtime: u8 = 'b'; // Runtime-known, so `byte == runtime` below is a runtime condition.
+            inline for ([3]u8{ 'a', 'b', 'c' }) |byte| {
+                bar();
+                if (byte == runtime) {
+                    foo(byte);
+                    break; // Break out of the inline loop from inside the runtime `if`.
+                }
+            }
+            try expect(ok); // foo() was reached with 'b'.
+            try expect(count == 2); // bar() ran for 'a' and 'b'; the break on 'b' skipped 'c'.
+        }
+        var ok = false;
+        var count: usize = 0;
+
+        fn foo(byte: u8) void {
+            ok = byte == 'b';
+        }
+
+        fn bar() void {
+            count += 1;
+        }
+    };
+
+    try S.doTheTest();
+}
+
+test "comptime break to outer loop passing through runtime condition converted to runtime break" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest() !void {
+            var runtime: u8 = 'b'; // Runtime-known, so `byte == runtime` below is a runtime condition.
+            outer: inline for ([3]u8{ 'A', 'B', 'C' }) |outer_byte| {
+                inline for ([3]u8{ 'a', 'b', 'c' }) |byte| {
+                    bar(outer_byte);
+                    if (byte == runtime) {
+                        foo(byte);
+                        break :outer; // Leaves both inline loops from inside the runtime `if`.
+                    }
+                }
+            }
+            try expect(ok); // foo() was reached with 'b'.
+            try expect(count == 2); // bar() ran for ('A','a') and ('A','b'); `break :outer` ended everything else.
+        }
+        var ok = false;
+        var count: usize = 0;
+
+        fn foo(byte: u8) void {
+            ok = byte == 'b';
+        }
+
+        fn bar(byte: u8) void {
+            _ = byte;
+            count += 1;
+        }
+    };
+
+    try S.doTheTest();
+}
+
+test "comptime break operand passing through runtime condition converted to runtime break" {
+    const S = struct {
+        fn doTheTest(runtime: u8) !void {
+            const result = inline for ([3]u8{ 'a', 'b', 'c' }) |byte| {
+                if (byte == runtime) {
+                    break runtime; // The break carries an operand, which becomes the loop's result.
+                }
+            } else 'z'; // Result when no iteration breaks.
+            try expect(result == 'b');
+        }
+    };
+
+    try S.doTheTest('b');
+    comptime try S.doTheTest('b'); // Same body, evaluated at comptime.
+}
+
+test "comptime break operand passing through runtime switch converted to runtime break" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest(runtime: u8) !void {
+            const result = inline for ([3]u8{ 'a', 'b', 'c' }) |byte| {
+                switch (runtime) {
+                    byte => break runtime, // Break with an operand from inside a switch prong.
+                    else => {},
+                }
+            } else 'z'; // Result when no iteration breaks.
+            try expect(result == 'b');
+        }
+    };
+
+    try S.doTheTest('b');
+    comptime try S.doTheTest('b'); // Same body, evaluated at comptime.
+}