Commit aaa641feba

Jakub Konka <kubkon@jakubkonka.com>
2022-01-18 11:58:22
stage2: add inline memset for x86_64 backend
* introduce a new Mir tag `mov_mem_index_imm` which selects an instruction of the form `OP ptr [reg + rax*1 + imm32], imm32`, where the encoded flags select the appropriate ptr width for the memory store operation (note that the scale is fixed and set to 1)
1 parent 4938fb8
Changed files (4)
src/arch/x86_64/CodeGen.zig
@@ -3148,7 +3148,7 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, mcv: MCValue) InnerErro
                 2 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaa }),
                 4 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaaaaaa }),
                 8 => return self.genSetStack(ty, stack_offset, .{ .immediate = 0xaaaaaaaaaaaaaaaa }),
-                else => return self.fail("TODO implement memset", .{}),
+                else => return self.genInlineMemset(ty, stack_offset, .{ .immediate = 0xaa }),
             }
         },
         .compare_flags_unsigned => |op| {
@@ -3398,6 +3398,97 @@ fn genInlineMemcpy(
     try self.performReloc(loop_reloc);
 }
 
+/// Emits an inline loop that stores `value` one byte at a time into the stack slot
+/// described by `ty` and `stack_offset`.
+///
+/// The generated code has the shape:
+///       mov rax, abi_size - 1
+///   loop:
+///       cmp rax, -1
+///       je end
+///       mov byte ptr [rbp + rax - (stack_offset + abi_size)], imm
+///       sub rax, 1
+///       jmp loop
+///   end:
+///
+/// Fails with a TODO error when `stack_offset + abi_size` exceeds 128, when `value` is an
+/// immediate greater than `maxInt(i32)`, or when `value` is anything other than `.immediate`.
+fn genInlineMemset(self: *Self, ty: Type, stack_offset: u32, value: MCValue) InnerError!void {
+    // rax serves below as both the loop counter and the index register of the store.
+    // NOTE(review): presumably `getReg` makes rax available to us (spilling its current
+    // owner if any) -- confirm against the register manager.
+    try self.register_manager.getReg(.rax, null);
+    const abi_size = ty.abiSize(self.target.*);
+    const adj_off = stack_offset + abi_size;
+    if (adj_off > 128) {
+        return self.fail("TODO inline memset with large stack offset", .{});
+    }
+    // Two's-complement bit pattern of -(stack_offset + abi_size); it is passed on as the
+    // displacement of the rbp-relative store emitted inside the loop.
+    const negative_offset = @bitCast(u32, -@intCast(i32, adj_off));
+
+    // We are actually counting `abi_size` bytes; however, we reuse the index register
+    // as both the counter and offset scaler, hence we need to subtract one from `abi_size`
+    // and count until -1.
+    // NOTE(review): since `adj_off > 128` already returned above, `abi_size` cannot exceed
+    // `maxInt(i32)` at this point, so the `movabs` branch looks unreachable for now -- confirm.
+    if (abi_size > math.maxInt(i32)) {
+        // movabs rax, abi_size - 1
+        const payload = try self.addExtra(Mir.Imm64.encode(abi_size - 1));
+        _ = try self.addInst(.{
+            .tag = .movabs,
+            .ops = (Mir.Ops{
+                .reg1 = .rax,
+            }).encode(),
+            .data = .{ .payload = payload },
+        });
+    } else {
+        // mov rax, abi_size - 1
+        _ = try self.addInst(.{
+            .tag = .mov,
+            .ops = (Mir.Ops{
+                .reg1 = .rax,
+            }).encode(),
+            .data = .{ .imm = @truncate(u32, abi_size - 1) },
+        });
+    }
+
+    // loop:
+    // cmp rax, -1
+    // The index of this instruction is kept so the backwards `jmp` at the bottom can target it.
+    const loop_start = try self.addInst(.{
+        .tag = .cmp,
+        .ops = (Mir.Ops{
+            .reg1 = .rax,
+        }).encode(),
+        .data = .{ .imm = @bitCast(u32, @as(i32, -1)) },
+    });
+
+    // je end
+    // The jump target is left undefined here; `performReloc(loop_reloc)` at the end of this
+    // function is handed this instruction once the loop body has been emitted.
+    const loop_reloc = try self.addInst(.{
+        .tag = .cond_jmp_eq_ne,
+        .ops = (Mir.Ops{ .flags = 0b01 }).encode(),
+        .data = .{ .inst = undefined },
+    });
+
+    switch (value) {
+        .immediate => |x| {
+            if (x > math.maxInt(i32)) {
+                return self.fail("TODO inline memset for value immediate larger than 32bits", .{});
+            }
+            // mov byte ptr [rbp + rax + stack_offset], imm
+            // `x` was bounded by `maxInt(i32)` above, so the truncation to u32 drops no set bits.
+            const payload = try self.addExtra(Mir.ImmPair{
+                .dest_off = negative_offset,
+                .operand = @truncate(u32, x),
+            });
+            // NOTE(review): `flags` is not set here; presumably it defaults to 0b00, which
+            // Emit lowers as `byte ptr` -- confirm against `Mir.Ops`.
+            _ = try self.addInst(.{
+                .tag = .mov_mem_index_imm,
+                .ops = (Mir.Ops{
+                    .reg1 = .rbp,
+                }).encode(),
+                .data = .{ .payload = payload },
+            });
+        },
+        else => return self.fail("TODO inline memset for value of type {}", .{value}),
+    }
+
+    // sub rax, 1
+    _ = try self.addInst(.{
+        .tag = .sub,
+        .ops = (Mir.Ops{
+            .reg1 = .rax,
+        }).encode(),
+        .data = .{ .imm = 1 },
+    });
+
+    // jmp loop
+    _ = try self.addInst(.{
+        .tag = .jmp,
+        .ops = (Mir.Ops{ .flags = 0b00 }).encode(),
+        .data = .{ .inst = loop_start },
+    });
+
+    // end:
+    try self.performReloc(loop_reloc);
+}
+
 fn genSetReg(self: *Self, ty: Type, reg: Register, mcv: MCValue) InnerError!void {
     switch (mcv) {
         .dead => unreachable,
@@ -3639,7 +3730,7 @@ fn airArrayToSlice(self: *Self, inst: Air.Inst.Index) !void {
         const stack_offset = try self.allocMem(inst, 16, 16);
         const array_ty = ptr_ty.childType();
         const array_len = array_ty.arrayLenIncludingSentinel();
-        try self.genSetStack(Type.initTag(.usize), stack_offset + 8, ptr);
+        try self.genSetStack(ptr_ty, stack_offset + 8, ptr);
         try self.genSetStack(Type.initTag(.u64), stack_offset + 16, .{ .immediate = array_len });
         break :blk .{ .stack_offset = stack_offset };
     };
src/arch/x86_64/Emit.zig
@@ -115,6 +115,16 @@ pub fn lowerMir(emit: *Emit) InnerError!void {
             .cmp_scale_imm => try emit.mirArithScaleImm(.cmp, inst),
             .mov_scale_imm => try emit.mirArithScaleImm(.mov, inst),
 
+            .adc_mem_index_imm => try emit.mirArithMemIndexImm(.adc, inst),
+            .add_mem_index_imm => try emit.mirArithMemIndexImm(.add, inst),
+            .sub_mem_index_imm => try emit.mirArithMemIndexImm(.sub, inst),
+            .xor_mem_index_imm => try emit.mirArithMemIndexImm(.xor, inst),
+            .and_mem_index_imm => try emit.mirArithMemIndexImm(.@"and", inst),
+            .or_mem_index_imm => try emit.mirArithMemIndexImm(.@"or", inst),
+            .sbb_mem_index_imm => try emit.mirArithMemIndexImm(.sbb, inst),
+            .cmp_mem_index_imm => try emit.mirArithMemIndexImm(.cmp, inst),
+            .mov_mem_index_imm => try emit.mirArithMemIndexImm(.mov, inst),
+
             .movabs => try emit.mirMovabs(inst),
 
             .lea => try emit.mirLea(inst),
@@ -549,6 +559,29 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     }), imm_pair.operand, emit.code) catch |err| emit.failWithLoweringError(err);
 }
 
+/// Lowers a `*_mem_index_imm` MIR instruction to machine code of the form
+/// `OP ptr [reg1 + rax*1 + imm32], imm32`.
+///
+/// `ops.flags` selects the pointer width of the memory operand, `ops.reg1` is the base
+/// register, and `ops.reg2` must be `.none`. The displacement (`dest_off`) and the immediate
+/// (`operand`) are read from the `Mir.ImmPair` that the instruction's payload points at.
+/// Errors reported by `lowerToMiEnc` are routed through `emit.failWithLoweringError`.
+fn mirArithMemIndexImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
+    const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
+    assert(ops.reg2 == .none);
+    const payload = emit.mir.instructions.items(.data)[inst].payload;
+    const imm_pair = emit.mir.extraData(Mir.ImmPair, payload).data;
+    // The two flag bits pick the width of the store; all four encodings are covered.
+    const ptr_size: Memory.PtrSize = switch (ops.flags) {
+        0b00 => .byte_ptr,
+        0b01 => .word_ptr,
+        0b10 => .dword_ptr,
+        0b11 => .qword_ptr,
+    };
+    // The index register is always rax for this instruction form.
+    // NOTE(review): `.scale = 0` presumably encodes a scale factor of 1 (i.e. 2^0), matching
+    // the `rax*1` in the comment below -- confirm against `ScaleIndex`.
+    const scale_index = ScaleIndex{
+        .scale = 0,
+        .index = .rax,
+    };
+    // OP ptr [reg1 + rax*1 + imm32], imm32
+    return lowerToMiEnc(tag, RegisterOrMemory.mem(ptr_size, .{
+        .disp = imm_pair.dest_off,
+        .base = ops.reg1,
+        .scale_index = scale_index,
+    }), imm_pair.operand, emit.code) catch |err| emit.failWithLoweringError(err);
+}
+
 fn mirMovabs(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     const tag = emit.mir.instructions.items(.tag)[inst];
     assert(tag == .movabs);
src/arch/x86_64/Mir.zig
@@ -79,6 +79,13 @@ pub const Inst = struct {
         ///  * Data field `payload` points at `ImmPair`.
         adc_scale_imm,
 
+        /// ops flags: form:
+        ///       0b00 byte ptr [reg1 + rax + imm32], imm8
+        ///       0b01 word ptr [reg1 + rax + imm32], imm16
+        ///       0b10 dword ptr [reg1 + rax + imm32], imm32
+        ///       0b11 qword ptr [reg1 + rax + imm32], imm32 (sign-extended to imm64)
+        adc_mem_index_imm,
+
         // The following instructions all have the same encoding as `adc`.
 
         add,
@@ -86,81 +93,97 @@ pub const Inst = struct {
         add_scale_src,
         add_scale_dst,
         add_scale_imm,
+        add_mem_index_imm,
         sub,
         sub_mem_imm,
         sub_scale_src,
         sub_scale_dst,
         sub_scale_imm,
+        sub_mem_index_imm,
         xor,
         xor_mem_imm,
         xor_scale_src,
         xor_scale_dst,
         xor_scale_imm,
+        xor_mem_index_imm,
         @"and",
         and_mem_imm,
         and_scale_src,
         and_scale_dst,
         and_scale_imm,
+        and_mem_index_imm,
         @"or",
         or_mem_imm,
         or_scale_src,
         or_scale_dst,
         or_scale_imm,
+        or_mem_index_imm,
         rol,
         rol_mem_imm,
         rol_scale_src,
         rol_scale_dst,
         rol_scale_imm,
+        rol_mem_index_imm,
         ror,
         ror_mem_imm,
         ror_scale_src,
         ror_scale_dst,
         ror_scale_imm,
+        ror_mem_index_imm,
         rcl,
         rcl_mem_imm,
         rcl_scale_src,
         rcl_scale_dst,
         rcl_scale_imm,
+        rcl_mem_index_imm,
         rcr,
         rcr_mem_imm,
         rcr_scale_src,
         rcr_scale_dst,
         rcr_scale_imm,
+        rcr_mem_index_imm,
         shl,
         shl_mem_imm,
         shl_scale_src,
         shl_scale_dst,
         shl_scale_imm,
+        shl_mem_index_imm,
         sal,
         sal_mem_imm,
         sal_scale_src,
         sal_scale_dst,
         sal_scale_imm,
+        sal_mem_index_imm,
         shr,
         shr_mem_imm,
         shr_scale_src,
         shr_scale_dst,
         shr_scale_imm,
+        shr_mem_index_imm,
         sar,
         sar_mem_imm,
         sar_scale_src,
         sar_scale_dst,
         sar_scale_imm,
+        sar_mem_index_imm,
         sbb,
         sbb_mem_imm,
         sbb_scale_src,
         sbb_scale_dst,
         sbb_scale_imm,
+        sbb_mem_index_imm,
         cmp,
         cmp_mem_imm,
         cmp_scale_src,
         cmp_scale_dst,
         cmp_scale_imm,
+        cmp_mem_index_imm,
         mov,
         mov_mem_imm,
         mov_scale_src,
         mov_scale_dst,
         mov_scale_imm,
+        mov_mem_index_imm,
 
         /// ops flags: form:
         ///      0b00  reg1, [reg2 + imm32]
src/arch/x86_64/PrintMir.zig
@@ -64,6 +64,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
             .@"or" => try print.mirArith(.@"or", inst, w),
             .sbb => try print.mirArith(.sbb, inst, w),
             .cmp => try print.mirArith(.cmp, inst, w),
+            .mov => try print.mirArith(.mov, inst, w),
 
             .adc_mem_imm => try print.mirArithMemImm(.adc, inst, w),
             .add_mem_imm => try print.mirArithMemImm(.add, inst, w),
@@ -73,6 +74,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
             .or_mem_imm => try print.mirArithMemImm(.@"or", inst, w),
             .sbb_mem_imm => try print.mirArithMemImm(.sbb, inst, w),
             .cmp_mem_imm => try print.mirArithMemImm(.cmp, inst, w),
+            .mov_mem_imm => try print.mirArithMemImm(.mov, inst, w),
 
             .adc_scale_src => try print.mirArithScaleSrc(.adc, inst, w),
             .add_scale_src => try print.mirArithScaleSrc(.add, inst, w),
@@ -82,6 +84,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
             .or_scale_src => try print.mirArithScaleSrc(.@"or", inst, w),
             .sbb_scale_src => try print.mirArithScaleSrc(.sbb, inst, w),
             .cmp_scale_src => try print.mirArithScaleSrc(.cmp, inst, w),
+            .mov_scale_src => try print.mirArithScaleSrc(.mov, inst, w),
 
             .adc_scale_dst => try print.mirArithScaleDst(.adc, inst, w),
             .add_scale_dst => try print.mirArithScaleDst(.add, inst, w),
@@ -91,6 +94,7 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
             .or_scale_dst => try print.mirArithScaleDst(.@"or", inst, w),
             .sbb_scale_dst => try print.mirArithScaleDst(.sbb, inst, w),
             .cmp_scale_dst => try print.mirArithScaleDst(.cmp, inst, w),
+            .mov_scale_dst => try print.mirArithScaleDst(.mov, inst, w),
 
             .adc_scale_imm => try print.mirArithScaleImm(.adc, inst, w),
             .add_scale_imm => try print.mirArithScaleImm(.add, inst, w),
@@ -100,11 +104,18 @@ pub fn printMir(print: *const Print, w: anytype, mir_to_air_map: std.AutoHashMap
             .or_scale_imm => try print.mirArithScaleImm(.@"or", inst, w),
             .sbb_scale_imm => try print.mirArithScaleImm(.sbb, inst, w),
             .cmp_scale_imm => try print.mirArithScaleImm(.cmp, inst, w),
-
-            .mov => try print.mirArith(.mov, inst, w),
-            .mov_scale_src => try print.mirArithScaleSrc(.mov, inst, w),
-            .mov_scale_dst => try print.mirArithScaleDst(.mov, inst, w),
             .mov_scale_imm => try print.mirArithScaleImm(.mov, inst, w),
+
+            .adc_mem_index_imm => try print.mirArithMemIndexImm(.adc, inst, w),
+            .add_mem_index_imm => try print.mirArithMemIndexImm(.add, inst, w),
+            .sub_mem_index_imm => try print.mirArithMemIndexImm(.sub, inst, w),
+            .xor_mem_index_imm => try print.mirArithMemIndexImm(.xor, inst, w),
+            .and_mem_index_imm => try print.mirArithMemIndexImm(.@"and", inst, w),
+            .or_mem_index_imm => try print.mirArithMemIndexImm(.@"or", inst, w),
+            .sbb_mem_index_imm => try print.mirArithMemIndexImm(.sbb, inst, w),
+            .cmp_mem_index_imm => try print.mirArithMemIndexImm(.cmp, inst, w),
+            .mov_mem_index_imm => try print.mirArithMemIndexImm(.mov, inst, w),
+
             .movabs => try print.mirMovabs(inst, w),
 
             .lea => try print.mirLea(inst, w),
@@ -316,11 +327,11 @@ fn mirArithScaleDst(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index
 
     if (ops.reg2 == .none) {
         // OP [reg1 + scale*rax + 0], imm32
-        try w.print("{s} [{s} + {d}*rcx + 0], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm });
+        try w.print("{s} [{s} + {d}*rax + 0], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm });
     }
 
     // OP [reg1 + scale*rax + imm32], reg2
-    try w.print("{s} [{s} + {d}*rcx + {d}], {s}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm, @tagName(ops.reg2) });
+    try w.print("{s} [{s} + {d}*rax + {d}], {s}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm, @tagName(ops.reg2) });
 }
 
 fn mirArithScaleImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index, w: anytype) !void {
@@ -328,7 +339,21 @@ fn mirArithScaleImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index
     const scale = ops.flags;
     const payload = print.mir.instructions.items(.data)[inst].payload;
     const imm_pair = print.mir.extraData(Mir.ImmPair, payload).data;
-    try w.print("{s} [{s} + {d}*rcx + {d}], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm_pair.dest_off, imm_pair.operand });
+    try w.print("{s} [{s} + {d}*rax + {d}], {d}\n", .{ @tagName(tag), @tagName(ops.reg1), scale, imm_pair.dest_off, imm_pair.operand });
+}
+
+/// Writes a textual form of a `*_mem_index_imm` MIR instruction to `w`, shaped as
+/// `OP <width> ptr [reg1 + 1*rax + dest_off], operand`.
+///
+/// The width keyword is chosen from `ops.flags`; `dest_off` and `operand` come from the
+/// `Mir.ImmPair` that the instruction's payload points at.
+fn mirArithMemIndexImm(print: *const Print, tag: Mir.Inst.Tag, inst: Mir.Inst.Index, w: anytype) !void {
+    const ops = Mir.Ops.decode(print.mir.instructions.items(.ops)[inst]);
+    const payload = print.mir.instructions.items(.data)[inst].payload;
+    const imm_pair = print.mir.extraData(Mir.ImmPair, payload).data;
+    try w.print("{s} ", .{@tagName(tag)});
+    // Same flag-to-width mapping as the one used when lowering this instruction in Emit.
+    switch (ops.flags) {
+        0b00 => try w.print("byte ptr ", .{}),
+        0b01 => try w.print("word ptr ", .{}),
+        0b10 => try w.print("dword ptr ", .{}),
+        0b11 => try w.print("qword ptr ", .{}),
+    }
+    // The index register and scale are fixed for this form, so they are printed literally.
+    // NOTE(review): `dest_off` appears to be a u32 that can carry a bit-cast negative offset
+    // (see `genInlineMemset`), in which case `{d}` prints it as a large unsigned number -- confirm.
+    try w.print("[{s} + 1*rax + {d}], {d}\n", .{ @tagName(ops.reg1), imm_pair.dest_off, imm_pair.operand });
 }
 
 fn mirMovabs(print: *const Print, inst: Mir.Inst.Index, w: anytype) !void {