Commit b7e2235973

Jakub Konka <kubkon@jakubkonka.com>
2021-12-29 20:55:05
stage2: lower 1-byte and 2-byte values saved to stack
* fix handling of `ah`, `bh`, `ch`, and `dh` registers (which are actually used as aliases to `dil`, etc. registers). Currenly, we treat them as aliases only meaning when we encounter `ah` we make sure to set the REX.W to promote the instruction to 64bits and use `dil` register instead - otherwise we might have mismatch between registers used in different parts of the codegen. In the future, we can and should use `ah`, etc. as upper 8bit halves of 16bit registers `ax`, etc. * fix bug in `airCmp` where `.cmp` MIR instruction shouldn't force type `Bool` but let the type of the original type propagate downwards - we need this to make an informed choice of the target register size and hence choose the right encoding down the line. * implement lowering of 1-byte and 2-byte values to stack and add matching stage2 tests for x86_64 codegen
1 parent 08ea1a2
Changed files (3)
src
test
stage2
src/arch/x86_64/CodeGen.zig
@@ -1642,11 +1642,11 @@ fn genBinMathOpMir(
                     });
                 },
                 .immediate => |imm| {
-                    // TODO I am not quite sure why we need to set the size of the register here...
+                    const abi_size = dst_ty.abiSize(self.target.*);
                     _ = try self.addInst(.{
                         .tag = mir_tag,
                         .ops = (Mir.Ops{
-                            .reg1 = dst_reg.to32(),
+                            .reg1 = registerAlias(dst_reg, @intCast(u32, abi_size)),
                         }).encode(),
                         .data = .{ .imm = @intCast(i32, imm) },
                     });
@@ -1751,13 +1751,14 @@ fn genIMulOpMir(self: *Self, dst_ty: Type, dst_mcv: MCValue, src_mcv: MCValue) !
                     });
                 },
                 .immediate => |imm| {
+                    // TODO take into account the type's ABI size when selecting the register alias
                     // register, immediate
                     if (imm <= math.maxInt(i32)) {
                         _ = try self.addInst(.{
                             .tag = .imul_complex,
                             .ops = (Mir.Ops{
-                                .reg1 = dst_reg,
-                                .reg2 = dst_reg,
+                                .reg1 = dst_reg.to32(),
+                                .reg2 = dst_reg.to32(),
                                 .flags = 0b10,
                             }).encode(),
                             .data = .{ .imm = @intCast(i32, imm) },
@@ -2147,7 +2148,7 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: math.CompareOperator) !void {
         // This instruction supports only signed 32-bit immediates at most.
         const src_mcv = try self.limitImmediateType(bin_op.rhs, i32);
 
-        try self.genBinMathOpMir(.cmp, Type.initTag(.bool), dst_mcv, src_mcv);
+        try self.genBinMathOpMir(.cmp, ty, dst_mcv, src_mcv);
         break :result switch (ty.isSignedInt()) {
             true => MCValue{ .compare_flags_signed = op },
             false => MCValue{ .compare_flags_unsigned = op },
@@ -2792,16 +2793,10 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, mcv: MCValue) InnerErro
                 return self.fail("TODO implement set stack variable with large stack offset", .{});
             }
             switch (abi_size) {
-                1 => {
-                    return self.fail("TODO implement set abi_size=1 stack variable with immediate", .{});
-                },
-                2 => {
-                    return self.fail("TODO implement set abi_size=2 stack variable with immediate", .{});
-                },
-                4 => {
+                1, 2, 4 => {
                     // We have a positive stack offset value but we want a twos complement negative
                     // offset from rbp, which is at the top of the stack frame.
-                    // mov    DWORD PTR [rbp+offset], immediate
+                    // mov [rbp+offset], immediate
                     const payload = try self.addExtra(Mir.ImmPair{
                         .dest_off = -@intCast(i32, adj_off),
                         .operand = @bitCast(i32, @intCast(u32, x_big)),
@@ -2810,7 +2805,12 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, mcv: MCValue) InnerErro
                         .tag = .mov_mem_imm,
                         .ops = (Mir.Ops{
                             .reg1 = .rbp,
-                            .flags = 0b10,
+                            .flags = switch (abi_size) {
+                                1 => 0b00,
+                                2 => 0b01,
+                                4 => 0b10,
+                                else => unreachable,
+                            },
                         }).encode(),
                         .data = .{ .payload = payload },
                     });
@@ -2954,11 +2954,12 @@ fn genSetReg(self: *Self, ty: Type, reg: Register, mcv: MCValue) InnerError!void
                 return;
             }
             if (x <= math.maxInt(i32)) {
+                const abi_size = ty.abiSize(self.target.*);
                 // Next best case: if we set the lower four bytes, the upper four will be zeroed.
                 _ = try self.addInst(.{
                     .tag = .mov,
                     .ops = (Mir.Ops{
-                        .reg1 = reg.to32(),
+                        .reg1 = registerAlias(reg, @intCast(u32, abi_size)),
                     }).encode(),
                     .data = .{ .imm = @intCast(i32, x) },
                 });
src/arch/x86_64/Emit.zig
@@ -468,7 +468,15 @@ fn mirArithMemImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     ) catch |err| emit.failWithLoweringError(err);
 }
 
-fn immOpSize(imm: i64) u8 {
+inline fn setRexWRegister(reg: Register) bool {
+    if (reg.size() == 64) return true;
+    return switch (reg) {
+        .ah, .bh, .ch, .dh => true,
+        else => false,
+    };
+}
+
+inline fn immOpSize(imm: i64) u8 {
     blk: {
         _ = math.cast(i8, imm) catch break :blk;
         return 8;
@@ -1370,7 +1378,10 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
                 encoder.opcode_1byte(0x66);
             }
             encoder.rex(.{
-                .w = tag.isSetCC(),
+                .w = switch (reg) {
+                    .ah, .bh, .ch, .dh => true,
+                    else => false,
+                },
                 .b = reg.isExtended(),
             });
             opc.encode(encoder);
@@ -1389,7 +1400,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
                     return error.OperandSizeMismatch;
                 }
                 encoder.rex(.{
-                    .w = tag.isSetCC(),
+                    .w = false,
                     .b = reg.isExtended(),
                 });
                 opc.encode(encoder);
@@ -1455,7 +1466,7 @@ fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8),
         encoder.opcode_1byte(0x66);
     }
     encoder.rex(.{
-        .w = reg.size() == 64,
+        .w = setRexWRegister(reg),
     });
     opc.encode(encoder);
     switch (reg.size()) {
@@ -1488,7 +1499,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Low
         encoder.opcode_1byte(0x66);
     }
     encoder.rex(.{
-        .w = reg.size() == 64,
+        .w = setRexWRegister(reg),
         .b = reg.isExtended(),
     });
     opc.encodeWithReg(encoder, reg);
@@ -1525,7 +1536,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
                 encoder.opcode_1byte(0x66);
             }
             encoder.rex(.{
-                .w = dst_reg.size() == 64,
+                .w = setRexWRegister(dst_reg),
                 .b = dst_reg.isExtended(),
             });
             opc.encode(encoder);
@@ -1623,7 +1634,7 @@ fn lowerToRmEnc(
             }
             const encoder = try Encoder.init(code, 3);
             encoder.rex(.{
-                .w = reg.size() == 64,
+                .w = setRexWRegister(reg) or setRexWRegister(src_reg),
                 .r = reg.isExtended(),
                 .b = src_reg.isExtended(),
             });
@@ -1645,7 +1656,7 @@ fn lowerToRmEnc(
                     return error.OperandSizeMismatch;
                 }
                 encoder.rex(.{
-                    .w = reg.size() == 64,
+                    .w = setRexWRegister(reg),
                     .r = reg.isExtended(),
                     .b = src_reg.isExtended(),
                 });
@@ -1676,7 +1687,7 @@ fn lowerToRmEnc(
                 }
             } else {
                 encoder.rex(.{
-                    .w = reg.size() == 64,
+                    .w = setRexWRegister(reg),
                     .r = reg.isExtended(),
                 });
                 opc.encode(encoder);
@@ -1706,7 +1717,7 @@ fn lowerToMrEnc(
             }
             const encoder = try Encoder.init(code, 3);
             encoder.rex(.{
-                .w = dst_reg.size() == 64,
+                .w = setRexWRegister(dst_reg) or setRexWRegister(reg),
                 .r = reg.isExtended(),
                 .b = dst_reg.isExtended(),
             });
@@ -1726,7 +1737,7 @@ fn lowerToMrEnc(
                     return error.OperandSizeMismatch;
                 }
                 encoder.rex(.{
-                    .w = dst_mem.ptr_size == .qword_ptr,
+                    .w = dst_mem.ptr_size == .qword_ptr or setRexWRegister(reg),
                     .r = reg.isExtended(),
                     .b = dst_reg.isExtended(),
                 });
@@ -1757,7 +1768,7 @@ fn lowerToMrEnc(
                 }
             } else {
                 encoder.rex(.{
-                    .w = dst_mem.ptr_size == .qword_ptr,
+                    .w = dst_mem.ptr_size == .qword_ptr or setRexWRegister(reg),
                     .r = reg.isExtended(),
                 });
                 opc.encode(encoder);
@@ -1794,7 +1805,7 @@ fn lowerToRmiEnc(
                 return error.OperandSizeMismatch;
             }
             encoder.rex(.{
-                .w = reg.size() == 64,
+                .w = setRexWRegister(reg) or setRexWRegister(src_reg),
                 .r = reg.isExtended(),
                 .b = src_reg.isExtended(),
             });
@@ -1812,7 +1823,7 @@ fn lowerToRmiEnc(
                     return error.OperandSizeMismatch;
                 }
                 encoder.rex(.{
-                    .w = reg.size() == 64,
+                    .w = setRexWRegister(reg),
                     .r = reg.isExtended(),
                     .b = src_reg.isExtended(),
                 });
@@ -1843,7 +1854,7 @@ fn lowerToRmiEnc(
                 }
             } else {
                 encoder.rex(.{
-                    .w = reg.size() == 64,
+                    .w = setRexWRegister(reg),
                     .r = reg.isExtended(),
                 });
                 opc.encode(encoder);
@@ -2089,7 +2100,7 @@ test "lower M encoding" {
     try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(null, 0x10, .qword_ptr), code.buffer());
     try expectEqualHexStrings("\xFF\x24\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [ds:0x10]");
     try lowerToMEnc(.seta, RegisterOrMemory.reg(.r11b), code.buffer());
-    try expectEqualHexStrings("\x49\x0F\x97\xC3", code.emitted(), "seta r11b");
+    try expectEqualHexStrings("\x41\x0F\x97\xC3", code.emitted(), "seta r11b");
 }
 
 test "lower O encoding" {
@@ -2111,9 +2122,9 @@ test "lower RMI encoding" {
         "imul rax, qword ptr [rbp - 8], 0x10",
     );
     try lowerToRmiEnc(.imul, .eax, RegisterOrMemory.mem(.rbp, -4, .dword_ptr), 0x10, code.buffer());
-    try expectEqualHexStrings("\x69\x45\xFC\x10\x00\x00\x00", code.emitted(), "imul ax, [rbp - 2], 0x10");
+    try expectEqualHexStrings("\x69\x45\xFC\x10\x00\x00\x00", code.emitted(), "imul eax, dword ptr [rbp - 4], 0x10");
     try lowerToRmiEnc(.imul, .ax, RegisterOrMemory.mem(.rbp, -2, .word_ptr), 0x10, code.buffer());
-    try expectEqualHexStrings("\x66\x69\x45\xFE\x10\x00", code.emitted(), "imul eax, [rbp - 4], 0x10");
+    try expectEqualHexStrings("\x66\x69\x45\xFE\x10\x00", code.emitted(), "imul ax, word ptr [rbp - 2], 0x10");
     try lowerToRmiEnc(.imul, .r12, RegisterOrMemory.reg(.r12), 0x10, code.buffer());
     try expectEqualHexStrings("\x4D\x69\xE4\x10\x00\x00\x00", code.emitted(), "imul r12, r12, 0x10");
     try lowerToRmiEnc(.imul, .r12w, RegisterOrMemory.reg(.r12w), 0x10, code.buffer());
test/stage2/x86_64.zig
@@ -1604,6 +1604,64 @@ pub fn addCases(ctx: *TestContext) !void {
                 ":2:28: error: cannot set address space of local variable 'foo'",
             });
         }
+
+        {
+            var case = ctx.exe("saving vars of different ABI size to stack", target);
+
+            case.addCompareOutput(
+                \\pub fn main() void {
+                \\    assert(callMe(2) == 24);
+                \\}
+                \\
+                \\fn callMe(a: u8) u8 {
+                \\    var b: u8 = a + 10;
+                \\    const c = 2 * b;
+                \\    return c;
+                \\}
+                \\
+                \\pub fn assert(ok: bool) void {
+                \\    if (!ok) unreachable; // assertion failure
+                \\}
+            ,
+                "",
+            );
+
+            case.addCompareOutput(
+                \\pub fn main() void {
+                \\    assert(callMe(2) == 24);
+                \\}
+                \\
+                \\fn callMe(a: u16) u16 {
+                \\    var b: u16 = a + 10;
+                \\    const c = 2 * b;
+                \\    return c;
+                \\}
+                \\
+                \\pub fn assert(ok: bool) void {
+                \\    if (!ok) unreachable; // assertion failure
+                \\}
+            ,
+                "",
+            );
+
+            case.addCompareOutput(
+                \\pub fn main() void {
+                \\    assert(callMe(2) == 24);
+                \\}
+                \\
+                \\fn callMe(a: u32) u32 {
+                \\    var b: u32 = a + 10;
+                \\    const c = 2 * b;
+                \\    return c;
+                \\}
+                \\
+                \\pub fn assert(ok: bool) void {
+                \\    if (!ok) unreachable; // assertion failure
+                \\}
+            ,
+                "",
+            );
+        }
     }
 }