Commit 888708ec8a

Wooster <wooster0@proton.me>
2024-07-15 20:18:38
Sema: support pointer subtraction
1 parent 89942eb
doc/langref/test_pointer_arithmetic.zig
@@ -11,6 +11,9 @@ test "pointer arithmetic with many-item pointer" {
     // slicing a many-item pointer without an end is equivalent to
     // pointer arithmetic: `ptr[start..] == ptr + start`
     try expect(ptr[1..] == ptr + 1);
+
+    // subtraction between any two pointers except slices based on element size is supported
+    try expect(&ptr[1] - &ptr[0] == 1);
 }
 
 test "pointer arithmetic with slices" {
doc/langref.html.in
@@ -1935,16 +1935,18 @@ or
           <li>{#syntax#}*T{#endsyntax#} - single-item pointer to exactly one item.
             <ul>
               <li>Supports deref syntax: {#syntax#}ptr.*{#endsyntax#}</li>
+              <li>Supports pointer subtraction: {#syntax#}ptr - ptr{#endsyntax#}</li>
             </ul>
           </li>
           <li>{#syntax#}[*]T{#endsyntax#} - many-item pointer to unknown number of items.
             <ul>
               <li>Supports index syntax: {#syntax#}ptr[i]{#endsyntax#}</li>
               <li>Supports slice syntax: {#syntax#}ptr[start..end]{#endsyntax#} and {#syntax#}ptr[start..]{#endsyntax#}</li>
-              <li>Supports pointer arithmetic: {#syntax#}ptr + x{#endsyntax#}, {#syntax#}ptr - x{#endsyntax#}</li>
-              <li>{#syntax#}T{#endsyntax#} must have a known size, which means that it cannot be
-              {#syntax#}anyopaque{#endsyntax#} or any other {#link|opaque type|opaque#}.</li>
+              <li>Supports pointer-integer arithmetic: {#syntax#}ptr + int{#endsyntax#}, {#syntax#}ptr - int{#endsyntax#}</li>
+              <li>Supports pointer subtraction: {#syntax#}ptr - ptr{#endsyntax#}</li>
             </ul>
+            {#syntax#}T{#endsyntax#} must have a known size, which means that it cannot be
+            {#syntax#}anyopaque{#endsyntax#} or any other {#link|opaque type|opaque#}.
           </li>
       </ul>
       <p>These types are closely related to {#link|Arrays#} and {#link|Slices#}:</p>
@@ -1954,6 +1956,7 @@ or
                 <li>Supports index syntax: {#syntax#}array_ptr[i]{#endsyntax#}</li>
                 <li>Supports slice syntax: {#syntax#}array_ptr[start..end]{#endsyntax#}</li>
                 <li>Supports len property: {#syntax#}array_ptr.len{#endsyntax#}</li>
+                <li>Supports pointer subtraction: {#syntax#}array_ptr - array_ptr{#endsyntax#}</li>
             </ul>
             </li>
         </ul>
src/InternPool.zig
@@ -1931,6 +1931,23 @@ pub const Key = union(enum) {
                 /// the original pointer type alignment must be used.
                 orig_ty: Index,
             };
+
+            pub fn eql(a: BaseAddr, b: BaseAddr) bool {
+                if (@as(Key.Ptr.BaseAddr.Tag, a) != @as(Key.Ptr.BaseAddr.Tag, b)) return false;
+
+                return switch (a) {
+                    .decl => |a_decl| a_decl == b.decl,
+                    .comptime_alloc => |a_alloc| a_alloc == b.comptime_alloc,
+                    .anon_decl => |ad| ad.val == b.anon_decl.val and
+                        ad.orig_ty == b.anon_decl.orig_ty,
+                    .int => true,
+                    .eu_payload => |a_eu_payload| a_eu_payload == b.eu_payload,
+                    .opt_payload => |a_opt_payload| a_opt_payload == b.opt_payload,
+                    .comptime_field => |a_comptime_field| a_comptime_field == b.comptime_field,
+                    .arr_elem => |a_elem| std.meta.eql(a_elem, b.arr_elem),
+                    .field => |a_field| std.meta.eql(a_field, b.field),
+                };
+            }
         };
     };
 
@@ -2369,21 +2386,8 @@ pub const Key = union(enum) {
                 const b_info = b.ptr;
                 if (a_info.ty != b_info.ty) return false;
                 if (a_info.byte_offset != b_info.byte_offset) return false;
-
-                if (@as(Key.Ptr.BaseAddr.Tag, a_info.base_addr) != @as(Key.Ptr.BaseAddr.Tag, b_info.base_addr)) return false;
-
-                return switch (a_info.base_addr) {
-                    .decl => |a_decl| a_decl == b_info.base_addr.decl,
-                    .comptime_alloc => |a_alloc| a_alloc == b_info.base_addr.comptime_alloc,
-                    .anon_decl => |ad| ad.val == b_info.base_addr.anon_decl.val and
-                        ad.orig_ty == b_info.base_addr.anon_decl.orig_ty,
-                    .int => true,
-                    .eu_payload => |a_eu_payload| a_eu_payload == b_info.base_addr.eu_payload,
-                    .opt_payload => |a_opt_payload| a_opt_payload == b_info.base_addr.opt_payload,
-                    .comptime_field => |a_comptime_field| a_comptime_field == b_info.base_addr.comptime_field,
-                    .arr_elem => |a_elem| std.meta.eql(a_elem, b_info.base_addr.arr_elem),
-                    .field => |a_field| std.meta.eql(a_field, b_info.base_addr.field),
-                };
+                if (!a_info.base_addr.eql(b_info.base_addr)) return false;
+                return true;
             },
 
             .int => |a_info| {
src/Sema.zig
@@ -2429,6 +2429,16 @@ fn failWithComptimeErrorRetTrace(
     return sema.failWithOwnedErrorMsg(block, msg);
 }
 
+fn failWithInvalidPtrArithmetic(sema: *Sema, block: *Block, src: LazySrcLoc, arithmetic: []const u8, supports: []const u8) CompileError {
+    const msg = msg: {
+        const msg = try sema.errMsg(src, "invalid {s} arithmetic operator", .{arithmetic});
+        errdefer msg.destroy(sema.gpa);
+        try sema.errNote(src, msg, "{s} arithmetic only supports {s}", .{ arithmetic, supports });
+        break :msg msg;
+    };
+    return sema.failWithOwnedErrorMsg(block, msg);
+}
+
 /// We don't return a pointer to the new error note because the pointer
 /// becomes invalid when you add another one.
 pub fn errNote(
@@ -15146,7 +15156,7 @@ fn zirDiv(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -15312,7 +15322,7 @@ fn zirDivExact(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -15478,7 +15488,7 @@ fn zirDivFloor(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -15589,7 +15599,7 @@ fn zirDivTrunc(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -15833,7 +15843,7 @@ fn zirModRem(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -16019,7 +16029,7 @@ fn zirMod(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -16115,7 +16125,7 @@ fn zirRem(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison(mod);
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
-    try sema.checkInvalidPtrArithmetic(block, src, lhs_ty);
+    try sema.checkInvalidPtrIntArithmetic(block, src, lhs_ty);
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -16458,17 +16468,78 @@ fn analyzeArithmetic(
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison(mod);
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
 
-    if (lhs_zig_ty_tag == .Pointer) switch (lhs_ty.ptrSize(mod)) {
-        .One, .Slice => {},
-        .Many, .C => {
-            const air_tag: Air.Inst.Tag = switch (zir_tag) {
-                .add => .ptr_add,
-                .sub => .ptr_sub,
-                else => return sema.fail(block, src, "invalid pointer arithmetic operator", .{}),
-            };
-            return sema.analyzePtrArithmetic(block, src, lhs, rhs, air_tag, lhs_src, rhs_src);
-        },
-    };
+    if (lhs_zig_ty_tag == .Pointer) {
+        if (rhs_zig_ty_tag == .Pointer) {
+            if (lhs_ty.ptrSize(mod) != .Slice and rhs_ty.ptrSize(mod) != .Slice) {
+                if (zir_tag != .sub) {
+                    return sema.failWithInvalidPtrArithmetic(block, src, "pointer-pointer", "subtraction");
+                }
+                if (!lhs_ty.elemType2(mod).eql(rhs_ty.elemType2(mod), mod)) {
+                    return sema.fail(block, src, "incompatible pointer arithmetic operands '{}' and '{}'", .{
+                        lhs_ty.fmt(pt), rhs_ty.fmt(pt),
+                    });
+                }
+
+                const elem_size = lhs_ty.elemType2(mod).abiSize(pt);
+                if (elem_size == 0) {
+                    return sema.fail(block, src, "pointer arithmetic requires element type '{}' to have runtime bits", .{
+                        lhs_ty.elemType2(mod).fmt(pt),
+                    });
+                }
+
+                const runtime_src = runtime_src: {
+                    if (try sema.resolveValue(lhs)) |lhs_value| {
+                        if (try sema.resolveValue(rhs)) |rhs_value| {
+                            const lhs_ptr = switch (mod.intern_pool.indexToKey(lhs_value.toIntern())) {
+                                .undef => return sema.failWithUseOfUndef(block, lhs_src),
+                                .ptr => |ptr| ptr,
+                                else => unreachable,
+                            };
+                            const rhs_ptr = switch (mod.intern_pool.indexToKey(rhs_value.toIntern())) {
+                                .undef => return sema.failWithUseOfUndef(block, rhs_src),
+                                .ptr => |ptr| ptr,
+                                else => unreachable,
+                            };
+                            // Make sure the pointers point to the same data.
+                            if (!lhs_ptr.base_addr.eql(rhs_ptr.base_addr)) break :runtime_src src;
+                            const address = std.math.sub(u64, lhs_ptr.byte_offset, rhs_ptr.byte_offset) catch
+                                return sema.fail(block, src, "operation results in overflow", .{});
+                            const result = address / elem_size;
+                            return try pt.intRef(Type.usize, result);
+                        } else {
+                            break :runtime_src lhs_src;
+                        }
+                    } else {
+                        break :runtime_src rhs_src;
+                    }
+                };
+
+                try sema.requireRuntimeBlock(block, src, runtime_src);
+                const lhs_int = try block.addUnOp(.int_from_ptr, lhs);
+                const rhs_int = try block.addUnOp(.int_from_ptr, rhs);
+                const address = try block.addBinOp(.sub_wrap, lhs_int, rhs_int);
+                return try block.addBinOp(.div_exact, address, try pt.intRef(Type.usize, elem_size));
+            }
+        } else {
+            switch (lhs_ty.ptrSize(mod)) {
+                .One, .Slice => {},
+                .Many, .C => {
+                    const air_tag: Air.Inst.Tag = switch (zir_tag) {
+                        .add => .ptr_add,
+                        .sub => .ptr_sub,
+                        else => return sema.failWithInvalidPtrArithmetic(block, src, "pointer-integer", "addition and subtraction"),
+                    };
+
+                    if (!try sema.typeHasRuntimeBits(lhs_ty.elemType2(mod))) {
+                        return sema.fail(block, src, "pointer arithmetic requires element type '{}' to have runtime bits", .{
+                            lhs_ty.elemType2(mod).fmt(pt),
+                        });
+                    }
+                    return sema.analyzePtrArithmetic(block, src, lhs, rhs, air_tag, lhs_src, rhs_src);
+                },
+            }
+        }
+    }
 
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions, .{
@@ -23762,7 +23833,7 @@ fn checkIntType(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) CompileEr
     }
 }
 
-fn checkInvalidPtrArithmetic(
+fn checkInvalidPtrIntArithmetic(
     sema: *Sema,
     block: *Block,
     src: LazySrcLoc,
@@ -23773,12 +23844,7 @@ fn checkInvalidPtrArithmetic(
     switch (try ty.zigTypeTagOrPoison(mod)) {
         .Pointer => switch (ty.ptrSize(mod)) {
             .One, .Slice => return,
-            .Many, .C => return sema.fail(
-                block,
-                src,
-                "invalid pointer arithmetic operator",
-                .{},
-            ),
+            .Many, .C => return sema.failWithInvalidPtrArithmetic(block, src, "pointer-integer", "addition and subtraction"),
         },
         else => return,
     }
src/Value.zig
@@ -3752,7 +3752,7 @@ pub fn ptrField(parent_ptr: Value, field_idx: u32, pt: Zcu.PerThread) !Value {
     const parent_ptr_info = parent_ptr_ty.ptrInfo(zcu);
     assert(parent_ptr_info.flags.size == .One);
 
-    // Exiting this `switch` indicates that the `field` pointer repsentation should be used.
+    // Exiting this `switch` indicates that the `field` pointer representation should be used.
     // `field_align` may be `.none` to represent the natural alignment of `field_ty`, but is not necessarily.
     const field_ty: Type, const field_align: InternPool.Alignment = switch (aggregate_ty.zigTypeTag(zcu)) {
         .Struct => field: {
test/behavior/pointers.zig
@@ -17,7 +17,7 @@ fn testDerefPtr() !void {
     try expect(x == 1235);
 }
 
-test "pointer arithmetic" {
+test "pointer-integer arithmetic" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
@@ -43,6 +43,62 @@ test "pointer arithmetic" {
     try expect(ptr[0] == 'a');
 }
 
+test "pointer subtraction" {
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    {
+        const a: *u8 = @ptrFromInt(100);
+        const b: *u8 = @ptrFromInt(50);
+        try expect(a - b == 50);
+    }
+    {
+        var ptr: [*]const u8 = "abc";
+        try expect(&ptr[1] - &ptr[0] == 1);
+        try expect(&ptr[2] - &ptr[0] == 2);
+    }
+    {
+        const a: *[100]u16 = @ptrFromInt(100);
+        const b: *[100]u16 = @ptrFromInt(50);
+        try expect(a - b == 25);
+    }
+    {
+        var x: struct { a: u32, b: u32 } = undefined;
+        const a = &x.a;
+        const b = &x.b;
+        try expect(a - a == 0);
+        try expect(b - b == 0);
+        try expect(b - a == 1);
+    }
+    comptime {
+        var x: packed struct { a: u1, b: u1 } = undefined;
+        const a = &x.a;
+        const b = &x.b;
+        try expect(a - a == 0);
+        try expect(b - b == 0);
+        try expect(b - a == 0);
+    }
+    comptime {
+        var x: extern struct { a: u32, b: u32 } = undefined;
+        const a = &x.a;
+        const b = &x.b;
+        try expect(a - a == 0);
+        try expect(b - b == 0);
+        try expect(b - a == 1);
+    }
+    comptime {
+        const a: *const [3]u8 = "abc";
+        const b: [*]const u8 = @ptrCast(a);
+        try expect(&a[1] - &b[0] == 1);
+    }
+    comptime {
+        var x: [64][64]u8 = undefined;
+        const a = &x[0][12];
+        const b = &x[15][3];
+        try expect(b - a == 951);
+    }
+}
+
 test "double pointer parsing" {
     comptime assert(PtrOf(PtrOf(i32)) == **i32);
 }
@@ -382,7 +438,7 @@ test "pointer to array at fixed address" {
     try expect(@intFromPtr(&array[1]) == 0x14);
 }
 
-test "pointer arithmetic affects the alignment" {
+test "pointer-integer arithmetic affects the alignment" {
     {
         var ptr: [*]align(8) u32 = undefined;
         var x: usize = 1;
test/cases/compile_errors/invalid_pointer_arithmetic.zig
@@ -0,0 +1,52 @@
+export fn a(x: [*]u8) void {
+    _ = x * 1;
+}
+
+export fn b(x: *u8) void {
+    _ = x * x;
+}
+
+export fn c() void {
+    const x: []u8 = undefined;
+    const y: []u8 = undefined;
+    _ = x - y;
+}
+
+export fn d() void {
+    var x: [*]u8 = undefined;
+    var y: [*]u16 = undefined;
+    _ = &x;
+    _ = &y;
+    _ = x - y;
+}
+
+comptime {
+    const x: *u8 = @ptrFromInt(1);
+    const y: *u16 = @ptrFromInt(2);
+    _ = x - y;
+}
+
+comptime {
+    const x: [*]u0 = @ptrFromInt(1);
+    _ = x + 1;
+}
+
+comptime {
+    const x: *u0 = @ptrFromInt(1);
+    const y: *u0 = @ptrFromInt(2);
+    _ = x - y;
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :2:11: error: invalid pointer-integer arithmetic operator
+// :2:11: note: pointer-integer arithmetic only supports addition and subtraction
+// :6:11: error: invalid pointer-pointer arithmetic operator
+// :6:11: note: pointer-pointer arithmetic only supports subtraction
+// :12:11: error: invalid operands to binary expression: 'Pointer' and 'Pointer'
+// :20:11: error: incompatible pointer arithmetic operands '[*]u8' and '[*]u16'
+// :26:11: error: incompatible pointer arithmetic operands '*u8' and '*u16'
+// :31:11: error: pointer arithmetic requires element type 'u0' to have runtime bits
+// :37:11: error: pointer arithmetic requires element type 'u0' to have runtime bits