Commit c96f9a017a

mlugg <mlugg@mlugg.co.uk>
2024-10-09 00:37:16
Sema: implement @splat for arrays
Resolves: #20433
1 parent 072e062
lib/std/zig/AstGen.zig
@@ -2716,7 +2716,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .array_type_sentinel,
             .elem_type,
             .indexable_ptr_elem_type,
-            .vector_elem_type,
+            .vec_arr_elem_type,
             .vector_type,
             .indexable_ptr_len,
             .anyframe_type,
@@ -9529,7 +9529,7 @@ fn builtinCall(
 
         .splat => {
             const result_type = try ri.rl.resultTypeForCast(gz, node, builtin_name);
-            const elem_type = try gz.addUnNode(.vector_elem_type, result_type, node);
+            const elem_type = try gz.addUnNode(.vec_arr_elem_type, result_type, node);
             const scalar = try expr(gz, scope, .{ .rl = .{ .ty = elem_type } }, params[0]);
             const result = try gz.addPlNode(.splat, node, Zir.Inst.Bin{
                 .lhs = result_type,
lib/std/zig/Zir.zig
@@ -247,9 +247,9 @@ pub const Inst = struct {
         /// element type. Emits a compile error if the type is not an indexable pointer.
         /// Uses the `un_node` field.
         indexable_ptr_elem_type,
-        /// Given a vector type, returns its element type.
+        /// Given a vector or array type, returns its element type.
         /// Uses the `un_node` field.
-        vector_elem_type,
+        vec_arr_elem_type,
         /// Given a pointer to an indexable object, returns the len property. This is
         /// used by for loops. This instruction also emits a for-loop specific compile
         /// error if the indexable object is not indexable.
@@ -1065,7 +1065,7 @@ pub const Inst = struct {
                 .vector_type,
                 .elem_type,
                 .indexable_ptr_elem_type,
-                .vector_elem_type,
+                .vec_arr_elem_type,
                 .indexable_ptr_len,
                 .anyframe_type,
                 .as_node,
@@ -1375,7 +1375,7 @@ pub const Inst = struct {
                 .vector_type,
                 .elem_type,
                 .indexable_ptr_elem_type,
-                .vector_elem_type,
+                .vec_arr_elem_type,
                 .indexable_ptr_len,
                 .anyframe_type,
                 .as_node,
@@ -1607,7 +1607,7 @@ pub const Inst = struct {
                 .vector_type = .pl_node,
                 .elem_type = .un_node,
                 .indexable_ptr_elem_type = .un_node,
-                .vector_elem_type = .un_node,
+                .vec_arr_elem_type = .un_node,
                 .indexable_ptr_len = .un_node,
                 .anyframe_type = .un_node,
                 .as_node = .pl_node,
@@ -3781,7 +3781,7 @@ fn findDeclsInner(
         .vector_type,
         .elem_type,
         .indexable_ptr_elem_type,
-        .vector_elem_type,
+        .vec_arr_elem_type,
         .indexable_ptr_len,
         .anyframe_type,
         .as_node,
src/print_zir.zig
@@ -203,7 +203,7 @@ const Writer = struct {
             .alloc_comptime_mut,
             .elem_type,
             .indexable_ptr_elem_type,
-            .vector_elem_type,
+            .vec_arr_elem_type,
             .indexable_ptr_len,
             .anyframe_type,
             .bit_not,
src/Sema.zig
@@ -1087,7 +1087,7 @@ fn analyzeBodyInner(
             .elem_val_imm                 => try sema.zirElemValImm(block, inst),
             .elem_type                    => try sema.zirElemType(block, inst),
             .indexable_ptr_elem_type      => try sema.zirIndexablePtrElemType(block, inst),
-            .vector_elem_type             => try sema.zirVectorElemType(block, inst),
+            .vec_arr_elem_type            => try sema.zirVecArrElemType(block, inst),
             .enum_literal                 => try sema.zirEnumLiteral(block, inst),
             .decl_literal                 => try sema.zirDeclLiteral(block, inst, true),
             .decl_literal_no_coerce       => try sema.zirDeclLiteral(block, inst, false),
@@ -2046,7 +2046,7 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi
                 const bin = sema.code.instructions.items(.data)[@intFromEnum(inst)].bin;
                 cur = bin.lhs;
             },
-            .indexable_ptr_elem_type, .vector_elem_type => {
+            .indexable_ptr_elem_type, .vec_arr_elem_type => {
                 const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
                 cur = un_node.operand;
             },
@@ -8603,7 +8603,7 @@ fn zirIndexablePtrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Com
     return Air.internedToRef(elem_ty.toIntern());
 }
 
-fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+fn zirVecArrElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const pt = sema.pt;
     const zcu = pt.zcu;
     const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
@@ -8615,8 +8615,9 @@ fn zirVectorElemType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
         error.GenericPoison => return .generic_poison_type,
         else => |e| return e,
     };
-    if (!vec_ty.isVector(zcu)) {
-        return sema.fail(block, block.nodeOffset(un_node.src_node), "expected vector type, found '{}'", .{vec_ty.fmt(pt)});
+    switch (vec_ty.zigTypeTag(zcu)) {
+        .array, .vector => {},
+        else => return sema.fail(block, block.nodeOffset(un_node.src_node), "expected array or vector type, found '{}'", .{vec_ty.fmt(pt)}),
     }
     return Air.internedToRef(vec_ty.childType(zcu).toIntern());
 }
@@ -24804,26 +24805,66 @@ fn zirSplat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.I
     const scalar_src = block.builtinCallArgSrc(inst_data.src_node, 0);
     const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@splat");
 
-    if (!dest_ty.isVector(zcu)) return sema.fail(block, src, "expected vector type, found '{}'", .{dest_ty.fmt(pt)});
+    switch (dest_ty.zigTypeTag(zcu)) {
+        .array, .vector => {},
+        else => return sema.fail(block, src, "expected array or vector type, found '{}'", .{dest_ty.fmt(pt)}),
+    }
 
-    if (!dest_ty.hasRuntimeBits(zcu)) {
+    const operand = try sema.resolveInst(extra.rhs);
+    const scalar_ty = dest_ty.childType(zcu);
+    const scalar = try sema.coerce(block, scalar_ty, operand, scalar_src);
+
+    const len = try sema.usizeCast(block, src, dest_ty.arrayLen(zcu));
+
+    // `len == 0` because `[0:s]T` always has a comptime-known splat.
+    if (!dest_ty.hasRuntimeBits(zcu) or len == 0) {
         const empty_aggregate = try pt.intern(.{ .aggregate = .{
             .ty = dest_ty.toIntern(),
-            .storage = .{ .elems = &[_]InternPool.Index{} },
+            .storage = .{ .elems = &.{} },
         } });
         return Air.internedToRef(empty_aggregate);
     }
 
-    const operand = try sema.resolveInst(extra.rhs);
-    const scalar_ty = dest_ty.childType(zcu);
-    const scalar = try sema.coerce(block, scalar_ty, operand, scalar_src);
+    const maybe_sentinel = dest_ty.sentinel(zcu);
+
     if (try sema.resolveValue(scalar)) |scalar_val| {
-        if (scalar_val.isUndef(zcu)) return pt.undefRef(dest_ty);
-        return Air.internedToRef((try sema.splat(dest_ty, scalar_val)).toIntern());
+        if (scalar_val.isUndef(zcu) and maybe_sentinel == null) {
+            return pt.undefRef(dest_ty);
+        }
+        // TODO: I didn't want to put `.aggregate` on a separate line here; `zig fmt` bugs have forced my hand
+        return Air.internedToRef(try pt.intern(.{
+            .aggregate = .{
+                .ty = dest_ty.toIntern(),
+                .storage = s: {
+                    full: {
+                        if (dest_ty.zigTypeTag(zcu) == .vector) break :full;
+                        const sentinel = maybe_sentinel orelse break :full;
+                        if (sentinel.toIntern() == scalar_val.toIntern()) break :full;
+                        // This is a array with non-zero length and a sentinel which does not match the element.
+                        // We have to use the full `elems` representation.
+                        const elems = try sema.arena.alloc(InternPool.Index, len + 1);
+                        @memset(elems[0..len], scalar_val.toIntern());
+                        elems[len] = sentinel.toIntern();
+                        break :s .{ .elems = elems };
+                    }
+                    break :s .{ .repeated_elem = scalar_val.toIntern() };
+                },
+            },
+        }));
     }
 
     try sema.requireRuntimeBlock(block, src, scalar_src);
-    return block.addTyOp(.splat, dest_ty, scalar);
+
+    switch (dest_ty.zigTypeTag(zcu)) {
+        .array => {
+            const elems = try sema.arena.alloc(Air.Inst.Ref, len + @intFromBool(maybe_sentinel != null));
+            @memset(elems[0..len], scalar);
+            if (maybe_sentinel) |s| elems[len] = Air.internedToRef(s.toIntern());
+            return block.addAggregateInit(dest_ty, elems);
+        },
+        .vector => return block.addTyOp(.splat, dest_ty, scalar),
+        else => unreachable,
+    }
 }
 
 fn zirReduce(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
test/behavior/array.zig
@@ -1021,3 +1021,70 @@ test "runtime index of array of zero-bit values" {
     try std.testing.expect(result.index == 0);
     try std.testing.expect(result.value == {});
 }
+
+test "@splat array" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    const S = struct {
+        fn doTheTest(comptime T: type, x: T) !void {
+            const arr: [10]T = @splat(x);
+            for (arr) |elem| {
+                try expectEqual(x, elem);
+            }
+        }
+    };
+
+    try S.doTheTest(u32, 123);
+    try comptime S.doTheTest(u32, 123);
+
+    const Foo = struct { x: u8 };
+    try S.doTheTest(Foo, .{ .x = 10 });
+    try comptime S.doTheTest(Foo, .{ .x = 10 });
+}
+
+test "@splat array with sentinel" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    const S = struct {
+        fn doTheTest(comptime T: type, x: T, comptime s: T) !void {
+            const arr: [10:s]T = @splat(x);
+            for (arr) |elem| {
+                try expectEqual(x, elem);
+            }
+            const ptr: [*]const T = &arr;
+            try expectEqual(s, ptr[10]); // sentinel correct
+        }
+    };
+
+    try S.doTheTest(u32, 100, 42);
+    try comptime S.doTheTest(u32, 100, 42);
+
+    try S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null);
+    try comptime S.doTheTest(?*anyopaque, @ptrFromInt(0x1000), null);
+}
+
+test "@splat zero-length array" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    const S = struct {
+        fn doTheTest(comptime T: type, comptime s: T) !void {
+            var runtime_undef: T = undefined;
+            runtime_undef = undefined;
+            // The array should be comptime-known despite the `@splat` operand being runtime-known.
+            const arr: [0:s]T = @splat(runtime_undef);
+            const ptr: [*]const T = &arr;
+            comptime assert(ptr[0] == s);
+        }
+    };
+
+    try S.doTheTest(u32, 42);
+    try comptime S.doTheTest(u32, 42);
+
+    try S.doTheTest(?*anyopaque, null);
+    try comptime S.doTheTest(?*anyopaque, null);
+}
test/cases/compile_errors/splat_bad_result_type.zig
@@ -0,0 +1,7 @@
+export fn f() void {
+    _ = @as(u32, @splat(5));
+}
+
+// error
+//
+// :2:18: error: expected array or vector type, found 'u32'
test/cases/compile_errors/splat_result_type_non_vector.zig
@@ -1,9 +0,0 @@
-export fn f() void {
-    _ = @as(u32, @splat(5));
-}
-
-// error
-// backend=stage2
-// target=native
-//
-// :2:18: error: expected vector type, found 'u32'