Commit ce88c43a4e

mlugg <mlugg@mlugg.co.uk>
2023-06-07 11:20:34
Sema: allow indexing tuple and vector pointers
Resolves: #13852 Resolves: #14705
1 parent 610b02c
Changed files (5)
src/Sema.zig
@@ -9991,7 +9991,7 @@ fn zirElemPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
                 indexable_ty.fmt(mod),
             });
             errdefer msg.destroy(sema.gpa);
-            if (indexable_ty.zigTypeTag(mod) == .Array) {
+            if (indexable_ty.isIndexable(mod)) {
                 try sema.errNote(block, src, msg, "consider using '&' here", .{});
             }
             break :msg msg;
@@ -26088,8 +26088,19 @@ fn elemPtrOneLayerOnly(
             return block.addPtrElemPtr(indexable, elem_index, result_ty);
         },
         .One => {
-            assert(indexable_ty.childType(mod).zigTypeTag(mod) == .Array); // Guaranteed by checkIndexable
-            return sema.elemPtrArray(block, src, indexable_src, indexable, elem_index_src, elem_index, init, oob_safety);
+            const child_ty = indexable_ty.childType(mod);
+            switch (child_ty.zigTypeTag(mod)) {
+                .Array, .Vector => {
+                    return sema.elemPtrArray(block, src, indexable_src, indexable, elem_index_src, elem_index, init, oob_safety);
+                },
+                .Struct => {
+                    assert(child_ty.isTuple(mod));
+                    const index_val = try sema.resolveConstValue(block, elem_index_src, elem_index, "tuple field access index must be comptime-known");
+                    const index = @intCast(u32, index_val.toUnsignedInt(mod));
+                    return sema.tupleFieldPtr(block, indexable_src, indexable, elem_index_src, index, false);
+                },
+                else => unreachable, // Guaranteed by checkIndexable
+            }
         },
     }
 }
@@ -26139,19 +26150,15 @@ fn elemVal(
                 return block.addBinOp(.ptr_elem_val, indexable, elem_index);
             },
             .One => {
-                const array_ty = indexable_ty.childType(mod); // Guaranteed by checkIndexable
-                assert(array_ty.zigTypeTag(mod) == .Array);
-
-                if (array_ty.sentinel(mod)) |sentinel| {
-                    // index must be defined since it can access out of bounds
-                    if (try sema.resolveDefinedValue(block, elem_index_src, elem_index)) |index_val| {
-                        const index = @intCast(usize, index_val.toUnsignedInt(mod));
-                        if (index == array_ty.arrayLen(mod)) {
-                            return sema.addConstant(array_ty.childType(mod), sentinel);
-                        }
-                    }
+                arr_sent: {
+                    const inner_ty = indexable_ty.childType(mod);
+                    if (inner_ty.zigTypeTag(mod) != .Array) break :arr_sent;
+                    const sentinel = inner_ty.sentinel(mod) orelse break :arr_sent;
+                    const index_val = try sema.resolveDefinedValue(block, elem_index_src, elem_index) orelse break :arr_sent;
+                    const index = try sema.usizeCast(block, src, index_val.toUnsignedInt(mod));
+                    if (index != inner_ty.arrayLen(mod)) break :arr_sent;
+                    return sema.addConstant(inner_ty.childType(mod), sentinel);
                 }
-
                 const elem_ptr = try sema.elemPtr(block, indexable_src, indexable, elem_index, elem_index_src, false, oob_safety);
                 return sema.analyzeLoad(block, indexable_src, elem_ptr, elem_index_src);
             },
src/type.zig
@@ -2832,7 +2832,11 @@ pub const Type = struct {
             .Array, .Vector => true,
             .Pointer => switch (ty.ptrSize(mod)) {
                 .Slice, .Many, .C => true,
-                .One => ty.childType(mod).zigTypeTag(mod) == .Array,
+                .One => switch (ty.childType(mod).zigTypeTag(mod)) {
+                    .Array, .Vector => true,
+                    .Struct => ty.childType(mod).isTuple(mod),
+                    else => false,
+                },
             },
             .Struct => ty.isTuple(mod),
             else => false,
@@ -2845,7 +2849,11 @@ pub const Type = struct {
             .Pointer => switch (ty.ptrSize(mod)) {
                 .Many, .C => false,
                 .Slice => true,
-                .One => ty.childType(mod).zigTypeTag(mod) == .Array,
+                .One => switch (ty.childType(mod).zigTypeTag(mod)) {
+                    .Array, .Vector => true,
+                    .Struct => ty.childType(mod).isTuple(mod),
+                    else => false,
+                },
             },
             .Struct => ty.isTuple(mod),
             else => false,
test/behavior/for.zig
@@ -463,3 +463,19 @@ test "inline for with counter as the comptime-known" {
 
     try expect(S.ok == 2);
 }
+
+test "inline for on tuple pointer" {
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const S = struct { u32, u32, u32 };
+    var s: S = .{ 100, 200, 300 };
+
+    inline for (&s, 0..) |*x, i| {
+        x.* = i;
+    }
+
+    try expectEqual(S{ 0, 1, 2 }, s);
+}
test/behavior/tuple.zig
@@ -1,6 +1,7 @@
 const builtin = @import("builtin");
 const std = @import("std");
 const testing = std.testing;
+const assert = std.debug.assert;
 const expect = testing.expect;
 const expectEqualStrings = std.testing.expectEqualStrings;
 const expectEqual = std.testing.expectEqual;
@@ -428,3 +429,27 @@ test "sentinel slice in tuple" {
 
     _ = S;
 }
+
+test "tuple pointer is indexable" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const S = struct { u32, bool };
+
+    const x: S = .{ 123, true };
+    comptime assert(@TypeOf(&(&x)[0]) == *const u32); // validate constness
+    try expectEqual(@as(u32, 123), (&x)[0]);
+    try expectEqual(true, (&x)[1]);
+
+    var y: S = .{ 123, true };
+    comptime assert(@TypeOf(&(&y)[0]) == *u32); // validate constness
+    try expectEqual(@as(u32, 123), (&y)[0]);
+    try expectEqual(true, (&y)[1]);
+
+    (&y)[0] = 100;
+    (&y)[1] = false;
+    try expectEqual(@as(u32, 100), (&y)[0]);
+    try expectEqual(false, (&y)[1]);
+}
test/behavior/vector.zig
@@ -2,6 +2,7 @@ const std = @import("std");
 const builtin = @import("builtin");
 const mem = std.mem;
 const math = std.math;
+const assert = std.debug.assert;
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
 
@@ -1343,3 +1344,29 @@ test "compare vectors with different element types" {
     var b: @Vector(2, u9) = .{ 3, 0 };
     try expectEqual(@Vector(2, bool){ true, false }, a < b);
 }
+
+test "vector pointer is indexable" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const V = @Vector(2, u32);
+
+    const x: V = .{ 123, 456 };
+    comptime assert(@TypeOf(&(&x)[0]) == *const u32); // validate constness
+    try expectEqual(@as(u32, 123), (&x)[0]);
+    try expectEqual(@as(u32, 456), (&x)[1]);
+
+    var y: V = .{ 123, 456 };
+    comptime assert(@TypeOf(&(&y)[0]) == *u32); // validate constness
+    try expectEqual(@as(u32, 123), (&y)[0]);
+    try expectEqual(@as(u32, 456), (&y)[1]);
+
+    (&y)[0] = 100;
+    (&y)[1] = 200;
+    try expectEqual(@as(u32, 100), (&y)[0]);
+    try expectEqual(@as(u32, 200), (&y)[1]);
+}