Commit d99bed1b10

mlugg <mlugg@mlugg.co.uk>
2023-11-07 08:13:25
Sema: optimize runtime array_mul
There are two optimizations here, which work together to avoid a pathological case. The first optimization is that AstGen now records the result type of an array multiplication expression where possible. The language specification does not give this expression a result type; the type is recorded purely as an optimization. In the expression '.{x} ** 1000', if we know that the result must be an array, then it is much more efficient to coerce the LHS to an array with length 1 before doing the multiplication. Otherwise, we end up with a 1000-element tuple which we must coerce to an array by individually extracting each field. Secondly, the previous logic would repeatedly extract element/field values from the LHS when initializing the result. This is unnecessary: each element needs to be extracted only once, and the extracted value can then be reused. These changes together give huge improvements to compiler performance on a pathological case: AIR instructions go from 65551 to 15, and total AIR bytes go from 1.86MiB to 264.57KiB. Codegen time spent on this function (in a debug compiler build) goes from minutes to essentially zero. Resolves: #17586
1 parent a1d688b
src/AstGen.zig
@@ -758,7 +758,11 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
         .array_cat        => return simpleBinOp(gz, scope, ri, node, .array_cat),
 
         .array_mult => {
-            const result = try gz.addPlNode(.array_mul, node, Zir.Inst.Bin{
+            // This syntax form does not currently use the result type in the language specification.
+            // However, the result type can be used to emit more optimal code for large multiplications by
+            // having Sema perform a coercion before the multiplication operation.
+            const result = try gz.addPlNode(.array_mul, node, Zir.Inst.ArrayMul{
+                .res_ty = if (try ri.rl.resultType(gz, node)) |t| t else .none,
                 .lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs),
                 .rhs = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs),
             });
src/Autodoc.zig
@@ -1567,7 +1567,6 @@ fn walkInstruction(
         .bit_and,
         .xor,
         .array_cat,
-        .array_mul,
         => {
             const pl_node = data[@intFromEnum(inst)].pl_node;
             const extra = file.zir.extraData(Zir.Inst.Bin, pl_node.payload_index);
src/print_zir.zig
@@ -370,7 +370,6 @@ const Writer = struct {
             .add_sat,
             .add_unsafe,
             .array_cat,
-            .array_mul,
             .mul,
             .mulwrap,
             .mul_sat,
@@ -431,6 +430,8 @@ const Writer = struct {
 
             .for_len => try self.writePlNodeMultiOp(stream, inst),
 
+            .array_mul => try self.writeArrayMul(stream, inst),
+
             .elem_val_imm => try self.writeElemValImm(stream, inst),
 
             .@"export" => try self.writePlNodeExport(stream, inst),
@@ -977,6 +978,18 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    fn writeArrayMul(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void { // Writes the operands of an `array_mul` instruction as "res_ty, lhs, rhs) " followed by its source location.
+        const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].pl_node; // `array_mul` uses the `pl_node` union field.
+        const extra = self.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data; // Decode the `ArrayMul` payload referenced by `payload_index`.
+        try self.writeInstRef(stream, extra.res_ty); // Result type first; AstGen stores `.none` here when no result type was available.
+        try stream.writeAll(", ");
+        try self.writeInstRef(stream, extra.lhs); // The operand being repeated.
+        try stream.writeAll(", ");
+        try self.writeInstRef(stream, extra.rhs); // The repetition count.
+        try stream.writeAll(") "); // NOTE(review): no "(" is written in this function; presumably the caller has already emitted it — confirm.
+        try self.writeSrc(stream, inst_data.src());
+    }
+
     fn writeElemValImm(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].elem_val_imm;
         try self.writeInstRef(stream, inst_data.operand);
src/Sema.zig
@@ -13998,14 +13998,49 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
     const mod = sema.mod;
     const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
-    const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
-    const lhs = try sema.resolveInst(extra.lhs);
-    const lhs_ty = sema.typeOf(lhs);
+    const extra = sema.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data;
+    const uncoerced_lhs = try sema.resolveInst(extra.lhs);
+    const uncoerced_lhs_ty = sema.typeOf(uncoerced_lhs);
     const src: LazySrcLoc = inst_data.src();
     const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
     const operator_src: LazySrcLoc = .{ .node_offset_main_token = inst_data.src_node };
     const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
 
+    const lhs, const lhs_ty = coerced_lhs: {
+        // If we have a result type, we might be able to do this more efficiently
+        // by coercing the LHS first. Specifically, if we want an array or vector
+        // and have a tuple, coerce the tuple immediately.
+        no_coerce: {
+            if (extra.res_ty == .none) break :no_coerce;
+            const res_ty_inst = try sema.resolveInst(extra.res_ty);
+            const res_ty = try sema.analyzeAsType(block, src, res_ty_inst);
+            if (res_ty.isGenericPoison()) break :no_coerce;
+            if (!uncoerced_lhs_ty.isTuple(mod)) break :no_coerce;
+            const lhs_len = uncoerced_lhs_ty.structFieldCount(mod);
+            const lhs_dest_ty = switch (res_ty.zigTypeTag(mod)) {
+                else => break :no_coerce,
+                .Array => try mod.arrayType(.{
+                    .child = res_ty.childType(mod).toIntern(),
+                    .len = lhs_len,
+                    .sentinel = if (res_ty.sentinel(mod)) |s| s.toIntern() else .none,
+                }),
+                .Vector => try mod.vectorType(.{
+                    .child = res_ty.childType(mod).toIntern(),
+                    .len = lhs_len,
+                }),
+            };
+            // Attempt to coerce to this type, but don't emit an error if it fails. Instead,
+            // just exit out of this path and let the usual error happen later, so that error
+            // messages are consistent.
+            const coerced = sema.coerceExtra(block, lhs_dest_ty, uncoerced_lhs, lhs_src, .{ .report_err = false }) catch |err| switch (err) {
+                error.NotCoercible => break :no_coerce,
+                else => |e| return e,
+            };
+            break :coerced_lhs .{ coerced, lhs_dest_ty };
+        }
+        break :coerced_lhs .{ uncoerced_lhs, uncoerced_lhs_ty };
+    };
+
     if (lhs_ty.isTuple(mod)) {
         // In `**` rhs must be comptime-known, but lhs can be runtime-known
         const factor = try sema.resolveInt(block, rhs_src, extra.rhs, Type.usize, .{
@@ -14086,6 +14121,14 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
     try sema.requireRuntimeBlock(block, src, lhs_src);
 
+    // Grab all the LHS values ahead of time, rather than repeatedly emitting instructions
+    // to get the same elem values.
+    const lhs_vals = try sema.arena.alloc(Air.Inst.Ref, lhs_len);
+    for (lhs_vals, 0..) |*lhs_val, idx| {
+        const idx_ref = try mod.intRef(Type.usize, idx);
+        lhs_val.* = try sema.elemVal(block, lhs_src, lhs, idx_ref, src, false);
+    }
+
     if (ptr_addrspace) |ptr_as| {
         const alloc_ty = try sema.ptrType(.{
             .child = result_ty.toIntern(),
@@ -14099,14 +14142,11 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
         var elem_i: usize = 0;
         while (elem_i < result_len) {
-            var lhs_i: usize = 0;
-            while (lhs_i < lhs_len) : (lhs_i += 1) {
+            for (lhs_vals) |lhs_val| {
                 const elem_index = try mod.intRef(Type.usize, elem_i);
-                elem_i += 1;
-                const lhs_index = try mod.intRef(Type.usize, lhs_i);
                 const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
-                const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
-                try sema.storePtr2(block, src, elem_ptr, src, init, lhs_src, .store);
+                try sema.storePtr2(block, src, elem_ptr, src, lhs_val, lhs_src, .store);
+                elem_i += 1;
             }
         }
         if (lhs_info.sentinel) |sent_val| {
@@ -14120,17 +14160,9 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     }
 
     const element_refs = try sema.arena.alloc(Air.Inst.Ref, result_len);
-    var elem_i: usize = 0;
-    while (elem_i < result_len) {
-        var lhs_i: usize = 0;
-        while (lhs_i < lhs_len) : (lhs_i += 1) {
-            const lhs_index = try mod.intRef(Type.usize, lhs_i);
-            const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
-            element_refs[elem_i] = init;
-            elem_i += 1;
-        }
+    for (0..try sema.usizeCast(block, rhs_src, factor)) |i| {
+        @memcpy(element_refs[i * lhs_len ..][0..lhs_len], lhs_vals);
     }
-
     return block.addAggregateInit(result_ty, element_refs);
 }
 
src/Zir.zig
@@ -250,7 +250,7 @@ pub const Inst = struct {
         /// Uses the `pl_node` union field. Payload is `Bin`.
         array_cat,
         /// Array multiplication `a ** b`
-        /// Uses the `pl_node` union field. Payload is `Bin`.
+        /// Uses the `pl_node` union field. Payload is `ArrayMul`.
         array_mul,
         /// `[N]T` syntax. No source location provided.
         /// Uses the `pl_node` union field. Payload is `Bin`. lhs is length, rhs is element type.
@@ -3373,6 +3373,15 @@ pub const Inst = struct {
         /// The expected field count.
         expect_len: u32,
     };
+
+    pub const ArrayMul = struct { // Payload of the `array_mul` instruction, which uses the `pl_node` union field.
+        /// The result type of the array multiplication operation, or `.none` if none was available.
+        res_ty: Ref, // Sema uses this only to try coercing a tuple LHS to an array/vector before multiplying; it is an optimization.
+        /// The LHS of the array multiplication.
+        lhs: Ref,
+        /// The RHS of the array multiplication.
+        rhs: Ref, // AstGen evaluates this with `comptimeExpr`, coerced to `usize`.
+    };
 };
 
 pub const SpecialProng = enum { none, @"else", under };