Commit 8b9161179d

Jacob Young <jacobly0@users.noreply.github.com>
2023-08-10 09:25:35
Sema: avoid deleting runtime side-effects in comptime initializers
Closes #16744
1 parent b835fd9
Changed files (4)
src/Sema.zig
@@ -4491,7 +4491,7 @@ fn validateUnionInit(
     _ = try sema.unionFieldIndex(block, union_ty, field_name, field_src);
     const air_tags = sema.air_instructions.items(.tag);
     const air_datas = sema.air_instructions.items(.data);
-    const field_ptr_air_ref = sema.inst_map.get(field_ptr).?;
+    const field_ptr_ref = sema.inst_map.get(field_ptr).?;
 
     // Our task here is to determine if the union is comptime-known. In such case,
     // we erase the runtime AIR instructions for initializing the union, and replace
@@ -4521,31 +4521,25 @@ fn validateUnionInit(
     var make_runtime = false;
     while (block_index > 0) : (block_index -= 1) {
         const store_inst = block.instructions.items[block_index];
-        if (Air.indexToRef(store_inst) == field_ptr_air_ref) break;
+        if (Air.indexToRef(store_inst) == field_ptr_ref) break;
         switch (air_tags[store_inst]) {
             .store, .store_safe => {},
             else => continue,
         }
         const bin_op = air_datas[store_inst].bin_op;
-        var lhs = bin_op.lhs;
-        if (Air.refToIndex(lhs)) |lhs_index| {
-            if (air_tags[lhs_index] == .bitcast) {
-                lhs = air_datas[lhs_index].ty_op.operand;
-                block_index -= 1;
-            }
-        }
-        if (lhs != field_ptr_air_ref) continue;
-        while (block_index > 0) : (block_index -= 1) {
-            const block_inst = block.instructions.items[block_index - 1];
-            if (air_tags[block_inst] != .dbg_stmt) break;
-        }
-        if (block_index > 0 and
-            field_ptr_air_ref == Air.indexToRef(block.instructions.items[block_index - 1]))
-        {
-            first_block_index = @min(first_block_index, block_index - 1);
-        } else {
-            first_block_index = @min(first_block_index, block_index);
-        }
+        var ptr_ref = bin_op.lhs;
+        if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+            ptr_ref = air_datas[ptr_inst].ty_op.operand;
+        };
+        if (ptr_ref != field_ptr_ref) continue;
+        first_block_index = @min(if (Air.refToIndex(field_ptr_ref)) |field_ptr_inst|
+            std.mem.lastIndexOfScalar(
+                Air.Inst.Index,
+                block.instructions.items[0..block_index],
+                field_ptr_inst,
+            ).?
+        else
+            block_index, first_block_index);
         init_val = try sema.resolveMaybeUndefValAllowVariablesMaybeRuntime(bin_op.rhs, &make_runtime);
         break;
     }
@@ -4557,7 +4551,29 @@ fn validateUnionInit(
     if (init_val) |val| {
         // Our task is to delete all the `field_ptr` and `store` instructions, and insert
         // instead a single `store` to the result ptr with a comptime union value.
-        block.instructions.shrinkRetainingCapacity(first_block_index);
+        block_index = first_block_index;
+        for (block.instructions.items[first_block_index..]) |cur_inst| {
+            switch (air_tags[cur_inst]) {
+                .struct_field_ptr,
+                .struct_field_ptr_index_0,
+                .struct_field_ptr_index_1,
+                .struct_field_ptr_index_2,
+                .struct_field_ptr_index_3,
+                => if (Air.indexToRef(cur_inst) == field_ptr_ref) continue,
+                .bitcast => if (air_datas[cur_inst].ty_op.operand == field_ptr_ref) continue,
+                .store, .store_safe => {
+                    var ptr_ref = air_datas[cur_inst].bin_op.lhs;
+                    if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+                        ptr_ref = air_datas[ptr_inst].ty_op.operand;
+                    };
+                    if (ptr_ref == field_ptr_ref) continue;
+                },
+                else => {},
+            }
+            block.instructions.items[block_index] = cur_inst;
+            block_index += 1;
+        }
+        block.instructions.shrinkRetainingCapacity(block_index);
 
         var union_val = try mod.intern(.{ .un = .{
             .ty = union_ty.toIntern(),
@@ -4590,6 +4606,9 @@ fn validateStructInit(
     const gpa = sema.gpa;
     const ip = &mod.intern_pool;
 
+    const field_indices = try gpa.alloc(u32, instrs.len);
+    defer gpa.free(field_indices);
+
     // Maps field index to field_ptr index of where it was already initialized.
     const found_fields = try gpa.alloc(Zir.Inst.Index, struct_ty.structFieldCount(mod));
     defer gpa.free(found_fields);
@@ -4597,7 +4616,7 @@ fn validateStructInit(
 
     var struct_ptr_zir_ref: Zir.Inst.Ref = undefined;
 
-    for (instrs) |field_ptr| {
+    for (instrs, field_indices) |field_ptr, *field_index| {
         const field_ptr_data = sema.code.instructions.items(.data)[field_ptr].pl_node;
         const field_src: LazySrcLoc = .{ .node_offset_initializer = field_ptr_data.src_node };
         const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data;
@@ -4606,12 +4625,12 @@ fn validateStructInit(
             gpa,
             sema.code.nullTerminatedString(field_ptr_extra.field_name_start),
         );
-        const field_index = if (struct_ty.isTuple(mod))
+        field_index.* = if (struct_ty.isTuple(mod))
             try sema.tupleFieldIndex(block, struct_ty, field_name, field_src)
         else
             try sema.structFieldIndex(block, struct_ty, field_name, field_src);
-        if (found_fields[field_index] != 0) {
-            const other_field_ptr = found_fields[field_index];
+        if (found_fields[field_index.*] != 0) {
+            const other_field_ptr = found_fields[field_index.*];
             const other_field_ptr_data = sema.code.instructions.items(.data)[other_field_ptr].pl_node;
             const other_field_src: LazySrcLoc = .{ .node_offset_initializer = other_field_ptr_data.src_node };
             const msg = msg: {
@@ -4622,7 +4641,7 @@ fn validateStructInit(
             };
             return sema.failWithOwnedErrorMsg(msg);
         }
-        found_fields[field_index] = field_ptr;
+        found_fields[field_index.*] = field_ptr;
     }
 
     var root_msg: ?*Module.ErrorMsg = null;
@@ -4708,7 +4727,7 @@ fn validateStructInit(
                 continue;
             }
 
-            const field_ptr_air_ref = sema.inst_map.get(field_ptr).?;
+            const field_ptr_ref = sema.inst_map.get(field_ptr).?;
 
             //std.debug.print("validateStructInit (field_ptr_air_inst=%{d}):\n", .{
             //    field_ptr_air_inst,
@@ -4738,7 +4757,7 @@ fn validateStructInit(
             var block_index = block.instructions.items.len - 1;
             while (block_index > 0) : (block_index -= 1) {
                 const store_inst = block.instructions.items[block_index];
-                if (Air.indexToRef(store_inst) == field_ptr_air_ref) {
+                if (Air.indexToRef(store_inst) == field_ptr_ref) {
                     struct_is_comptime = false;
                     continue :field;
                 }
@@ -4747,26 +4766,19 @@ fn validateStructInit(
                     else => continue,
                 }
                 const bin_op = air_datas[store_inst].bin_op;
-                var lhs = bin_op.lhs;
-                {
-                    const lhs_index = Air.refToIndex(lhs) orelse continue;
-                    if (air_tags[lhs_index] == .bitcast) {
-                        lhs = air_datas[lhs_index].ty_op.operand;
-                        block_index -= 1;
-                    }
-                }
-                if (lhs != field_ptr_air_ref) continue;
-                while (block_index > 0) : (block_index -= 1) {
-                    const block_inst = block.instructions.items[block_index - 1];
-                    if (air_tags[block_inst] != .dbg_stmt) break;
-                }
-                if (block_index > 0 and
-                    field_ptr_air_ref == Air.indexToRef(block.instructions.items[block_index - 1]))
-                {
-                    first_block_index = @min(first_block_index, block_index - 1);
-                } else {
-                    first_block_index = @min(first_block_index, block_index);
-                }
+                var ptr_ref = bin_op.lhs;
+                if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+                    ptr_ref = air_datas[ptr_inst].ty_op.operand;
+                };
+                if (ptr_ref != field_ptr_ref) continue;
+                first_block_index = @min(if (Air.refToIndex(field_ptr_ref)) |field_ptr_inst|
+                    std.mem.lastIndexOfScalar(
+                        Air.Inst.Index,
+                        block.instructions.items[0..block_index],
+                        field_ptr_inst,
+                    ).?
+                else
+                    block_index, first_block_index);
                 if (try sema.resolveMaybeUndefValAllowVariablesMaybeRuntime(bin_op.rhs, &make_runtime)) |val| {
                     field_values[i] = val.toIntern();
                 } else if (require_comptime) {
@@ -4822,8 +4834,40 @@ fn validateStructInit(
     if (struct_is_comptime) {
         // Our task is to delete all the `field_ptr` and `store` instructions, and insert
         // instead a single `store` to the struct_ptr with a comptime struct value.
+        var init_index: usize = 0;
+        var field_ptr_ref = Air.Inst.Ref.none;
+        var block_index = first_block_index;
+        for (block.instructions.items[first_block_index..]) |cur_inst| {
+            while (field_ptr_ref == .none and init_index < instrs.len) : (init_index += 1) {
+                const field_ty = struct_ty.structFieldType(field_indices[init_index], mod);
+                if (try field_ty.onePossibleValue(mod)) |_| continue;
+                field_ptr_ref = sema.inst_map.get(instrs[init_index]).?;
+            }
+            switch (air_tags[cur_inst]) {
+                .struct_field_ptr,
+                .struct_field_ptr_index_0,
+                .struct_field_ptr_index_1,
+                .struct_field_ptr_index_2,
+                .struct_field_ptr_index_3,
+                => if (Air.indexToRef(cur_inst) == field_ptr_ref) continue,
+                .bitcast => if (air_datas[cur_inst].ty_op.operand == field_ptr_ref) continue,
+                .store, .store_safe => {
+                    var ptr_ref = air_datas[cur_inst].bin_op.lhs;
+                    if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+                        ptr_ref = air_datas[ptr_inst].ty_op.operand;
+                    };
+                    if (ptr_ref == field_ptr_ref) {
+                        field_ptr_ref = .none;
+                        continue;
+                    }
+                },
+                else => {},
+            }
+            block.instructions.items[block_index] = cur_inst;
+            block_index += 1;
+        }
+        block.instructions.shrinkRetainingCapacity(block_index);
 
-        block.instructions.shrinkRetainingCapacity(first_block_index);
         var struct_val = try mod.intern(.{ .aggregate = .{
             .ty = struct_ty.toIntern(),
             .storage = .{ .elems = field_values },
@@ -4950,7 +4994,7 @@ fn zirValidateArrayInit(
             }
         }
 
-        const elem_ptr_air_ref = sema.inst_map.get(elem_ptr).?;
+        const elem_ptr_ref = sema.inst_map.get(elem_ptr).?;
 
         // We expect to see something like this in the current block AIR:
         //   %a = elem_ptr(...)
@@ -4975,7 +5019,7 @@ fn zirValidateArrayInit(
         var block_index = block.instructions.items.len - 1;
         while (block_index > 0) : (block_index -= 1) {
             const store_inst = block.instructions.items[block_index];
-            if (Air.indexToRef(store_inst) == elem_ptr_air_ref) {
+            if (Air.indexToRef(store_inst) == elem_ptr_ref) {
                 array_is_comptime = false;
                 continue :outer;
             }
@@ -4984,26 +5028,19 @@ fn zirValidateArrayInit(
                 else => continue,
             }
             const bin_op = air_datas[store_inst].bin_op;
-            var lhs = bin_op.lhs;
-            {
-                const lhs_index = Air.refToIndex(lhs) orelse continue;
-                if (air_tags[lhs_index] == .bitcast) {
-                    lhs = air_datas[lhs_index].ty_op.operand;
-                    block_index -= 1;
-                }
-            }
-            if (lhs != elem_ptr_air_ref) continue;
-            while (block_index > 0) : (block_index -= 1) {
-                const block_inst = block.instructions.items[block_index - 1];
-                if (air_tags[block_inst] != .dbg_stmt) break;
-            }
-            if (block_index > 0 and
-                elem_ptr_air_ref == Air.indexToRef(block.instructions.items[block_index - 1]))
-            {
-                first_block_index = @min(first_block_index, block_index - 1);
-            } else {
-                first_block_index = @min(first_block_index, block_index);
-            }
+            var ptr_ref = bin_op.lhs;
+            if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+                ptr_ref = air_datas[ptr_inst].ty_op.operand;
+            };
+            if (ptr_ref != elem_ptr_ref) continue;
+            first_block_index = @min(if (Air.refToIndex(elem_ptr_ref)) |elem_ptr_inst|
+                std.mem.lastIndexOfScalar(
+                    Air.Inst.Index,
+                    block.instructions.items[0..block_index],
+                    elem_ptr_inst,
+                ).?
+            else
+                block_index, first_block_index);
             if (try sema.resolveMaybeUndefValAllowVariablesMaybeRuntime(bin_op.rhs, &make_runtime)) |val| {
                 element_vals[i] = val.toIntern();
             } else {
@@ -5028,7 +5065,33 @@ fn zirValidateArrayInit(
 
         // Our task is to delete all the `elem_ptr` and `store` instructions, and insert
         // instead a single `store` to the array_ptr with a comptime struct value.
-        block.instructions.shrinkRetainingCapacity(first_block_index);
+        var elem_index: usize = 0;
+        var elem_ptr_ref = Air.Inst.Ref.none;
+        var block_index = first_block_index;
+        for (block.instructions.items[first_block_index..]) |cur_inst| {
+            while (elem_ptr_ref == .none and elem_index < instrs.len) : (elem_index += 1) {
+                if (array_ty.isTuple(mod) and array_ty.structFieldIsComptime(elem_index, mod)) continue;
+                elem_ptr_ref = sema.inst_map.get(instrs[elem_index]).?;
+            }
+            switch (air_tags[cur_inst]) {
+                .ptr_elem_ptr => if (Air.indexToRef(cur_inst) == elem_ptr_ref) continue,
+                .bitcast => if (air_datas[cur_inst].ty_op.operand == elem_ptr_ref) continue,
+                .store, .store_safe => {
+                    var ptr_ref = air_datas[cur_inst].bin_op.lhs;
+                    if (Air.refToIndex(ptr_ref)) |ptr_inst| if (air_tags[ptr_inst] == .bitcast) {
+                        ptr_ref = air_datas[ptr_inst].ty_op.operand;
+                    };
+                    if (ptr_ref == elem_ptr_ref) {
+                        elem_ptr_ref = .none;
+                        continue;
+                    }
+                },
+                else => {},
+            }
+            block.instructions.items[block_index] = cur_inst;
+            block_index += 1;
+        }
+        block.instructions.shrinkRetainingCapacity(block_index);
 
         var array_val = try mod.intern(.{ .aggregate = .{
             .ty = array_ty.toIntern(),
test/behavior/array.zig
@@ -775,3 +775,27 @@ test "array init with no result pointer sets field result types" {
 
     try expect(y == x);
 }
+
+test "runtime side-effects in comptime-known array init" {
+    var side_effects: u4 = 0;
+    const init = [4]u4{
+        blk: {
+            side_effects += 1;
+            break :blk 1;
+        },
+        blk: {
+            side_effects += 2;
+            break :blk 2;
+        },
+        blk: {
+            side_effects += 4;
+            break :blk 4;
+        },
+        blk: {
+            side_effects += 8;
+            break :blk 8;
+        },
+    };
+    try expectEqual([4]u4{ 1, 2, 4, 8 }, init);
+    try expectEqual(@as(u4, std.math.maxInt(u4)), side_effects);
+}
test/behavior/struct.zig
@@ -1738,3 +1738,28 @@ test "struct init with no result pointer sets field result types" {
 
     try expect(y == x);
 }
+
+test "runtime side-effects in comptime-known struct init" {
+    var side_effects: u4 = 0;
+    const S = struct { a: u4, b: u4, c: u4, d: u4 };
+    const init = S{
+        .d = blk: {
+            side_effects += 8;
+            break :blk 8;
+        },
+        .c = blk: {
+            side_effects += 4;
+            break :blk 4;
+        },
+        .b = blk: {
+            side_effects += 2;
+            break :blk 2;
+        },
+        .a = blk: {
+            side_effects += 1;
+            break :blk 1;
+        },
+    };
+    try expectEqual(S{ .a = 1, .b = 2, .c = 4, .d = 8 }, init);
+    try expectEqual(@as(u4, std.math.maxInt(u4)), side_effects);
+}
tools/lldb_pretty_printers.py
@@ -347,15 +347,9 @@ class TagAndPayload_SynthProvider:
         except: return -1
     def get_child_at_index(self, index): return (self.tag, self.payload)[index] if index in range(2) else None
 
-def Zir_Inst__Zir_Inst_Ref_SummaryProvider(value, _=None):
-    members = value.type.enum_members
-    # ignore .var_args_param_type and .none
-    return value if any(value.unsigned == member.unsigned for member in members) else 'instructions[%d]' % (value.unsigned + 2 - len(members))
-
-def Air_Inst__Air_Inst_Ref_SummaryProvider(value, _=None):
-    members = value.type.enum_members
-    # ignore .var_args_param_type and .none
-    return value if any(value.unsigned == member.unsigned for member in members) else 'instructions[%d]' % (value.unsigned + 2 - len(members))
+def InstRef_SummaryProvider(value, _=None):
+    return value if any(value.unsigned == member.unsigned for member in value.type.enum_members) else (
+        'InternPool.Index(%d)' % value.unsigned if value.unsigned < 0x80000000 else 'instructions[%d]' % (value.unsigned - 0x80000000))
 
 class Module_Decl__Module_Decl_Index_SynthProvider:
     def __init__(self, value, _=None): self.value = value
@@ -700,9 +694,9 @@ def __lldb_init_module(debugger, _=None):
     add(debugger, category='zig.stage2', type='Zir.Inst', identifier='TagAndPayload', synth=True, inline_children=True, summary=True)
     add(debugger, category='zig.stage2', regex=True, type=MultiArrayList_Entry('Zir\\.Inst'), identifier='TagAndPayload', synth=True, inline_children=True, summary=True)
     add(debugger, category='zig.stage2', regex=True, type='^Zir\\.Inst\\.Data\\.Data__struct_[1-9][0-9]*$', inline_children=True, summary=True)
-    add(debugger, category='zig.stage2', type='Zir.Inst::Zir.Inst.Ref', summary=True)
+    add(debugger, category='zig.stage2', type='Zir.Inst::Zir.Inst.Ref', identifier='InstRef', summary=True)
     add(debugger, category='zig.stage2', type='Air.Inst', identifier='TagAndPayload', synth=True, inline_children=True, summary=True)
-    add(debugger, category='zig.stage2', type='Air.Inst::Air.Inst.Ref', summary=True)
+    add(debugger, category='zig.stage2', type='Air.Inst::Air.Inst.Ref', identifier='InstRef', summary=True)
     add(debugger, category='zig.stage2', regex=True, type=MultiArrayList_Entry('Air\\.Inst'), identifier='TagAndPayload', synth=True, inline_children=True, summary=True)
     add(debugger, category='zig.stage2', regex=True, type='^Air\\.Inst\\.Data\\.Data__struct_[1-9][0-9]*$', inline_children=True, summary=True)
     add(debugger, category='zig.stage2', type='Module.Decl::Module.Decl.Index', synth=True)