Commit 1365be5d02

mlugg <mlugg@mlugg.co.uk>
2024-09-15 16:45:20
compiler: provide correct result types to `+=` and `-=`
Resolves: #21341
1 parent 5d7fa55
Changed files (5)
lib/std/zig/AstGen.zig
@@ -3785,8 +3785,26 @@ fn assignOp(
         else => undefined,
     };
     const lhs = try gz.addUnNode(.load, lhs_ptr, infix_node);
-    const lhs_type = try gz.addUnNode(.typeof, lhs, infix_node);
-    const rhs = try expr(gz, scope, .{ .rl = .{ .coerced_ty = lhs_type } }, node_datas[infix_node].rhs);
+
+    const rhs_res_ty = switch (op_inst_tag) {
+        .add,
+        .sub,
+        => try gz.add(.{
+            .tag = .extended,
+            .data = .{ .extended = .{
+                .opcode = .inplace_arith_result_ty,
+                .small = @intFromEnum(@as(Zir.Inst.InplaceOp, switch (op_inst_tag) {
+                    .add => .add_eq,
+                    .sub => .sub_eq,
+                    else => unreachable,
+                })),
+                .operand = @intFromEnum(lhs),
+            } },
+        }),
+        else => try gz.addUnNode(.typeof, lhs, infix_node), // same as LHS type
+    };
+    // Not `coerced_ty` since `add`/etc won't coerce to this type.
+    const rhs = try expr(gz, scope, .{ .rl = .{ .ty = rhs_res_ty } }, node_datas[infix_node].rhs);
 
     switch (op_inst_tag) {
         .add, .sub, .mul, .div, .mod_rem => {
lib/std/zig/Zir.zig
@@ -2086,6 +2086,10 @@ pub const Inst = struct {
         /// `operand` is payload index to `UnNode`.
         /// `small` is unused.
         branch_hint,
+        /// Compute the result type for in-place arithmetic, e.g. `+=`.
+        /// `operand` is `Zir.Inst.Ref` of the loaded LHS (*not* its type).
+        /// `small` is an `Inst.InplaceOp`.
+        inplace_arith_result_ty,
 
         pub const InstData = struct {
             opcode: Extended,
@@ -3188,6 +3192,11 @@ pub const Inst = struct {
         calling_convention_inline,
     };
 
+    pub const InplaceOp = enum(u16) {
+        add_eq,
+        sub_eq,
+    };
+
     /// Trailing:
     /// 0. tag_type: Ref, // if has_tag_type
     /// 1. captures_len: u32, // if has_captures_len
@@ -4032,6 +4041,7 @@ fn findDeclsInner(
                 .field_parent_ptr,
                 .builtin_value,
                 .branch_hint,
+                .inplace_arith_result_ty,
                 => return,
 
                 // `@TypeOf` has a body.
src/print_zir.zig
@@ -620,6 +620,7 @@ const Writer = struct {
             .closure_get => try self.writeClosureGet(stream, extended),
             .field_parent_ptr => try self.writeFieldParentPtr(stream, extended),
             .builtin_value => try self.writeBuiltinValue(stream, extended),
+            .inplace_arith_result_ty => try self.writeInplaceArithResultTy(stream, extended),
         }
     }
 
@@ -2781,6 +2782,12 @@ const Writer = struct {
         try self.writeSrcNode(stream, @bitCast(extended.operand));
     }
 
+    fn writeInplaceArithResultTy(self: *Writer, stream: anytype, extended: Zir.Inst.Extended.InstData) !void {
+        const op: Zir.Inst.InplaceOp = @enumFromInt(extended.small);
+        try self.writeInstRef(stream, @enumFromInt(extended.operand));
+        try stream.print(", {s}))", .{@tagName(op)});
+    }
+
     fn writeInstRef(self: *Writer, stream: anytype, ref: Zir.Inst.Ref) !void {
         if (ref == .none) {
             return stream.writeAll(".none");
src/Sema.zig
@@ -1361,6 +1361,7 @@ fn analyzeBodyInner(
                     .value_placeholder => unreachable, // never appears in a body
                     .field_parent_ptr => try sema.zirFieldParentPtr(block, extended),
                     .builtin_value => try sema.zirBuiltinValue(extended),
+                    .inplace_arith_result_ty => try sema.zirInplaceArithResultTy(extended),
                 };
             },
 
@@ -27342,6 +27343,33 @@ fn zirBuiltinValue(sema: *Sema, extended: Zir.Inst.Extended.InstData) CompileErr
     return Air.internedToRef(ty.toIntern());
 }
 
+fn zirInplaceArithResultTy(sema: *Sema, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
+    const pt = sema.pt;
+    const zcu = pt.zcu;
+
+    const lhs = try sema.resolveInst(@enumFromInt(extended.operand));
+    const lhs_ty = sema.typeOf(lhs);
+
+    const op: Zir.Inst.InplaceOp = @enumFromInt(extended.small);
+    const ty: Type = switch (op) {
+        .add_eq => ty: {
+            const ptr_size = lhs_ty.ptrSizeOrNull(zcu) orelse break :ty lhs_ty;
+            switch (ptr_size) {
+                .One, .Slice => break :ty lhs_ty, // invalid, let it error
+                .Many, .C => break :ty .usize, // `[*]T + usize`
+            }
+        },
+        .sub_eq => ty: {
+            const ptr_size = lhs_ty.ptrSizeOrNull(zcu) orelse break :ty lhs_ty;
+            switch (ptr_size) {
+                .One, .Slice => break :ty lhs_ty, // invalid, let it error
+                .Many, .C => break :ty .generic_poison, // could be `[*]T - [*]T` or `[*]T - usize`
+            }
+        },
+    };
+    return Air.internedToRef(ty.toIntern());
+}
+
 fn zirBranchHint(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!void {
     const pt = sema.pt;
     const zcu = pt.zcu;
test/behavior/pointers.zig
@@ -98,6 +98,21 @@ test "pointer subtraction" {
     }
 }
 
+test "pointer arithmetic with non-trivial RHS" {
+    var t: bool = undefined;
+    t = true;
+
+    var ptr: [*]const u8 = "Hello, World!";
+    ptr += if (t) 5 else 2;
+    try expect(ptr[0] == ',');
+    ptr += if (!t) 4 else 2;
+    try expect(ptr[0] == 'W');
+    ptr -= if (t) @as(usize, 6) else 3;
+    try expect(ptr[0] == 'e');
+    ptr -= if (!t) @as(usize, 0) else 1;
+    try expect(ptr[0] == 'H');
+}
+
 test "double pointer parsing" {
     comptime assert(PtrOf(PtrOf(i32)) == **i32);
 }