Commit cb5d2b691a

Veikka Tuominen <git@vexu.eu>
2022-06-04 12:12:55
Sema: validate equality on store to comptime field
1 parent a040ccb
Changed files (5)
src/Sema.zig
@@ -3616,7 +3616,7 @@ fn zirValidateArrayInit(
     const air_tags = sema.air_instructions.items(.tag);
     const air_datas = sema.air_instructions.items(.data);
 
-    for (instrs) |elem_ptr, i| {
+    outer: for (instrs) |elem_ptr, i| {
         const elem_ptr_data = sema.code.instructions.items(.data)[elem_ptr].pl_node;
         const elem_src: LazySrcLoc = .{ .node_offset = elem_ptr_data.src_node };
 
@@ -3630,6 +3630,10 @@ fn zirValidateArrayInit(
         // of the for loop.
         var block_index = block.instructions.items.len - 1;
         while (block.instructions.items[block_index] != elem_ptr_air_inst) {
+            if (block_index == 0) {
+                array_is_comptime = true;
+                continue :outer;
+            }
             block_index -= 1;
         }
         first_block_index = @minimum(first_block_index, block_index);
@@ -3672,6 +3676,13 @@ fn zirValidateArrayInit(
     }
 
     if (array_is_comptime) {
+        if (try sema.resolveDefinedValue(block, init_src, array_ptr)) |ptr_val| {
+            if (ptr_val.tag() == .comptime_field_ptr) {
+                // This store was validated by the individual elem ptrs.
+                return;
+            }
+        }
+
         // Our task is to delete all the `elem_ptr` and `store` instructions, and insert
         // instead a single `store` to the array_ptr with a comptime struct value.
         // Also to populate the sentinel value, if any.
@@ -18462,14 +18473,11 @@ fn structFieldPtrByIndex(
     const ptr_field_ty = try Type.ptr(sema.arena, sema.mod, ptr_ty_data);
 
     if (field.is_comptime) {
-        var anon_decl = try block.startAnonDecl(field_src);
-        defer anon_decl.deinit();
-        const decl = try anon_decl.finish(
-            try field.ty.copy(anon_decl.arena()),
-            try field.default_val.copy(anon_decl.arena()),
-            ptr_ty_data.@"align",
-        );
-        return sema.analyzeDeclRef(decl);
+        const val = try Value.Tag.comptime_field_ptr.create(sema.arena, .{
+            .field_ty = try field.ty.copy(sema.arena),
+            .field_val = try field.default_val.copy(sema.arena),
+        });
+        return sema.addConstant(ptr_field_ty, val);
     }
 
     if (try sema.resolveDefinedValue(block, src, struct_ptr)) |struct_ptr_val| {
@@ -20247,6 +20255,14 @@ fn storePtrVal(
 
     const bitcasted_val = try sema.bitCastVal(block, src, operand_val, operand_ty, mut_kit.ty, 0);
 
+    if (mut_kit.decl_ref_mut.runtime_index == std.math.maxInt(u32)) {
+        // Special case for comptime field ptr.
+        if (!mut_kit.val.eql(bitcasted_val, mut_kit.ty, sema.mod)) {
+            return sema.fail(block, src, "value stored in comptime field does not match the default value of the field", .{});
+        }
+        return;
+    }
+
     const arena = mut_kit.beginArena(sema.mod);
     defer mut_kit.finishArena(sema.mod);
 
@@ -20296,6 +20312,19 @@ fn beginComptimePtrMutation(
                 .ty = decl.ty,
             };
         },
+        .comptime_field_ptr => {
+            const payload = ptr_val.castTag(.comptime_field_ptr).?.data;
+            const duped = try sema.arena.create(Value);
+            duped.* = payload.field_val;
+            return ComptimePtrMutationKit{
+                .decl_ref_mut = .{
+                    .decl_index = @intToEnum(Module.Decl.Index, 0),
+                    .runtime_index = std.math.maxInt(u32),
+                },
+                .val = duped,
+                .ty = payload.field_ty,
+            };
+        },
         .elem_ptr => {
             const elem_ptr = ptr_val.castTag(.elem_ptr).?.data;
             var parent = try beginComptimePtrMutation(sema, block, src, elem_ptr.array_ptr);
src/TypedValue.zig
@@ -264,6 +264,16 @@ pub fn print(
                 .val = decl.val,
             }, writer, level - 1, mod);
         },
+        .comptime_field_ptr => {
+            const payload = val.castTag(.comptime_field_ptr).?.data;
+            if (level == 0) {
+                return writer.writeAll("(comptime field ptr)");
+            }
+            return print(.{
+                .ty = payload.field_ty,
+                .val = payload.field_val,
+            }, writer, level - 1, mod);
+        },
         .elem_ptr => {
             const elem_ptr = val.castTag(.elem_ptr).?.data;
             try writer.writeAll("&");
src/value.zig
@@ -120,6 +120,8 @@ pub const Value = extern union {
         /// This Tag will never be seen by machine codegen backends. It is changed into a
         /// `decl_ref` when a comptime variable goes out of scope.
         decl_ref_mut,
+        /// Behaves like `decl_ref_mut` but validates that the stored value matches the field value.
+        comptime_field_ptr,
         /// Pointer to a specific element of an array, vector or slice.
         elem_ptr,
         /// Pointer to a specific field of a struct or union.
@@ -316,6 +318,7 @@ pub const Value = extern union {
                 .aggregate => Payload.Aggregate,
                 .@"union" => Payload.Union,
                 .bound_fn => Payload.BoundFn,
+                .comptime_field_ptr => Payload.ComptimeFieldPtr,
             };
         }
 
@@ -506,6 +509,18 @@ pub const Value = extern union {
                 };
                 return Value{ .ptr_otherwise = &new_payload.base };
             },
+            .comptime_field_ptr => {
+                const payload = self.cast(Payload.ComptimeFieldPtr).?;
+                const new_payload = try arena.create(Payload.ComptimeFieldPtr);
+                new_payload.* = .{
+                    .base = payload.base,
+                    .data = .{
+                        .field_val = try payload.data.field_val.copy(arena),
+                        .field_ty = try payload.data.field_ty.copy(arena),
+                    },
+                };
+                return Value{ .ptr_otherwise = &new_payload.base };
+            },
             .elem_ptr => {
                 const payload = self.castTag(.elem_ptr).?;
                 const new_payload = try arena.create(Payload.ElemPtr);
@@ -754,6 +769,9 @@ pub const Value = extern union {
                 const decl_index = val.castTag(.decl_ref).?.data;
                 return out_stream.print("(decl_ref {d})", .{decl_index});
             },
+            .comptime_field_ptr => {
+                return out_stream.writeAll("(comptime_field_ptr)");
+            },
             .elem_ptr => {
                 const elem_ptr = val.castTag(.elem_ptr).?.data;
                 try out_stream.print("&[{}] ", .{elem_ptr.index});
@@ -1706,6 +1724,7 @@ pub const Value = extern union {
             .int_big_negative => return self.castTag(.int_big_negative).?.asBigInt().bitCountTwosComp(),
 
             .decl_ref_mut,
+            .comptime_field_ptr,
             .extern_fn,
             .decl_ref,
             .function,
@@ -1770,6 +1789,7 @@ pub const Value = extern union {
             .bool_true,
             .decl_ref,
             .decl_ref_mut,
+            .comptime_field_ptr,
             .extern_fn,
             .function,
             .variable,
@@ -2362,7 +2382,7 @@ pub const Value = extern union {
 
     pub fn isComptimeMutablePtr(val: Value) bool {
         return switch (val.tag()) {
-            .decl_ref_mut => true,
+            .decl_ref_mut, .comptime_field_ptr => true,
             .elem_ptr => isComptimeMutablePtr(val.castTag(.elem_ptr).?.data.array_ptr),
             .field_ptr => isComptimeMutablePtr(val.castTag(.field_ptr).?.data.container_ptr),
             .eu_payload_ptr => isComptimeMutablePtr(val.castTag(.eu_payload_ptr).?.data.container_ptr),
@@ -2426,6 +2446,9 @@ pub const Value = extern union {
                 const decl: Module.Decl.Index = ptr_val.pointerDecl().?;
                 std.hash.autoHash(hasher, decl);
             },
+            .comptime_field_ptr => {
+                std.hash.autoHash(hasher, Value.Tag.comptime_field_ptr);
+            },
 
             .elem_ptr => {
                 const elem_ptr = ptr_val.castTag(.elem_ptr).?.data;
@@ -2471,7 +2494,7 @@ pub const Value = extern union {
         return switch (val.tag()) {
             .slice => val.castTag(.slice).?.data.ptr,
             // TODO this should require being a slice tag, and not allow decl_ref, field_ptr, etc.
-            .decl_ref, .decl_ref_mut, .field_ptr, .elem_ptr => val,
+            .decl_ref, .decl_ref_mut, .field_ptr, .elem_ptr, .comptime_field_ptr => val,
             else => unreachable,
         };
     }
@@ -2497,6 +2520,14 @@ pub const Value = extern union {
                     return 1;
                 }
             },
+            .comptime_field_ptr => {
+                const payload = val.castTag(.comptime_field_ptr).?.data;
+                if (payload.field_ty.zigTypeTag() == .Array) {
+                    return payload.field_ty.arrayLen();
+                } else {
+                    return 1;
+                }
+            },
             else => unreachable,
         };
     }
@@ -2587,6 +2618,7 @@ pub const Value = extern union {
 
             .decl_ref => return mod.declPtr(val.castTag(.decl_ref).?.data).val.elemValueAdvanced(mod, index, arena, buffer),
             .decl_ref_mut => return mod.declPtr(val.castTag(.decl_ref_mut).?.data.decl_index).val.elemValueAdvanced(mod, index, arena, buffer),
+            .comptime_field_ptr => return val.castTag(.comptime_field_ptr).?.data.field_val.elemValueAdvanced(mod, index, arena, buffer),
             .elem_ptr => {
                 const data = val.castTag(.elem_ptr).?.data;
                 return data.array_ptr.elemValueAdvanced(mod, index + data.index, arena, buffer);
@@ -2623,6 +2655,7 @@ pub const Value = extern union {
 
             .decl_ref => sliceArray(mod.declPtr(val.castTag(.decl_ref).?.data).val, mod, arena, start, end),
             .decl_ref_mut => sliceArray(mod.declPtr(val.castTag(.decl_ref_mut).?.data.decl_index).val, mod, arena, start, end),
+            .comptime_field_ptr => sliceArray(val.castTag(.comptime_field_ptr).?.data.field_val, mod, arena, start, end),
             .elem_ptr => blk: {
                 const elem_ptr = val.castTag(.elem_ptr).?.data;
                 break :blk sliceArray(elem_ptr.array_ptr, mod, arena, start + elem_ptr.index, end + elem_ptr.index);
@@ -4742,6 +4775,14 @@ pub const Value = extern union {
             },
         };
 
+        pub const ComptimeFieldPtr = struct {
+            base: Payload,
+            data: struct {
+                field_val: Value,
+                field_ty: Type,
+            },
+        };
+
         pub const ElemPtr = struct {
             pub const base_tag = Tag.elem_ptr;
 
test/behavior/struct.zig
@@ -1336,3 +1336,25 @@ test "packed struct field access via pointer" {
     try S.doTheTest();
     comptime try S.doTheTest();
 }
+
+test "store to comptime field" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
+    {
+        const S = struct {
+            comptime a: [2]u32 = [2]u32{ 1, 2 },
+        };
+        var s: S = .{};
+        s.a = [2]u32{ 1, 2 };
+        s.a[0] = 1;
+    }
+    {
+        const T = struct { a: u32, b: u32 };
+        const S = struct {
+            comptime a: T = T{ .a = 1, .b = 2 },
+        };
+        var s: S = .{};
+        s.a = T{ .a = 1, .b = 2 };
+        s.a.a = 1;
+    }
+}
test/cases/compile_errors/invalid_store_to_comptime_field.zig
@@ -0,0 +1,20 @@
+pub export fn entry() void {
+    const S = struct {
+        comptime a: [2]u32 = [2]u32{ 1, 2 },
+    };
+    var s: S = .{};
+    s.a = [2]u32{ 2, 2 };
+}
+pub export fn entry1() void {
+    const T = struct { a: u32, b: u32 };
+    const S = struct {
+        comptime a: T = T{ .a = 1, .b = 2 },
+    };
+    var s: S = .{};
+    s.a = T{ .a = 2, .b = 2 };
+}
+// error
+// backend=stage2,llvm
+//
+// :6:19: error: value stored in comptime field does not match the default value of the field
+// :14:19: error: value stored in comptime field does not match the default value of the field