Commit d112cd52f3

Jakub Konka <kubkon@jakubkonka.com>
2022-05-04 23:00:41
aarch64: fix mul_with_overflow for ints <= 32bits
1 parent f4421c0
Changed files (4)
src/arch/aarch64/bits.zig
@@ -330,6 +330,17 @@ pub const Instruction = union(enum) {
         op: u1,
         sf: u1,
     },
+    add_subtract_extended_register: packed struct { // fields are listed lowest bit first; widths total 32 bits so the value can be @bitCast to u32
+        rd: u5,
+        rn: u5,
+        imm3: u3,
+        option: u3, // holds the ordinal of an AddSubtractExtendedRegisterOption
+        rm: u5,
+        fixed: u8 = 0b01011_00_1, // constant bits; addSubtractExtendedRegister leaves this default untouched
+        s: u1,
+        op: u1,
+        sf: u1, // 0 when the destination register is 32-bit, 1 when it is 64-bit
+    },
     conditional_branch: struct {
         cond: u4,
         o0: u1,
@@ -495,6 +506,7 @@ pub const Instruction = union(enum) {
             .logical_immediate => |v| @bitCast(u32, v),
             .bitfield => |v| @bitCast(u32, v),
             .add_subtract_shifted_register => |v| @bitCast(u32, v),
+            .add_subtract_extended_register => |v| @bitCast(u32, v),
             // TODO once packed structs work, this can be refactored
             .conditional_branch => |v| @as(u32, v.cond) | (@as(u32, v.o0) << 4) | (@as(u32, v.imm19) << 5) | (@as(u32, v.o1) << 24) | (@as(u32, v.fixed) << 25),
             .compare_and_branch => |v| @as(u32, v.rt) | (@as(u32, v.imm19) << 5) | (@as(u32, v.op) << 24) | (@as(u32, v.fixed) << 25) | (@as(u32, v.sf) << 31),
@@ -1006,6 +1018,44 @@ pub const Instruction = union(enum) {
         };
     }
 
+    pub const AddSubtractExtendedRegisterOption = enum(u3) { // declaration order defines the 3-bit encoding (0..7); do not reorder
+        uxtb,
+        uxth,
+        uxtw,
+        uxtx, // serves also as lsl
+        sxtb,
+        sxth,
+        sxtw,
+        sxtx, // ordinal 7 (0b111), as exercised by the "serialize instructions" test
+    };
+
+    fn addSubtractExtendedRegister( // shared encoder behind the add/adds/sub/subs extended-register constructors
+        op: u1,
+        s: u1,
+        rd: Register,
+        rn: Register,
+        rm: Register,
+        extend: AddSubtractExtendedRegisterOption,
+        imm3: u3,
+    ) Instruction {
+        return Instruction{
+            .add_subtract_extended_register = .{
+                .rd = rd.enc(),
+                .rn = rn.enc(),
+                .imm3 = imm3,
+                .option = @enumToInt(extend), // the enum ordinal is the encoded value
+                .rm = rm.enc(),
+                .s = s,
+                .op = op,
+                .sf = switch (rd.size()) { // width bit is derived from rd alone; rn and rm sizes are not validated here
+                    32 => 0b0,
+                    64 => 0b1,
+                    else => unreachable, // unexpected register size
+                },
+            },
+        };
+    }
+
     fn conditionalBranch(
         o0: u1,
         o1: u1,
@@ -1524,6 +1574,48 @@ pub const Instruction = union(enum) {
         return addSubtractShiftedRegister(0b1, 0b1, shift, rd, rn, rm, imm6);
     }
 
+    // Add/subtract (extended register)
+
+    pub fn addExtendedRegister(
+        rd: Register,
+        rn: Register,
+        rm: Register,
+        extend: AddSubtractExtendedRegisterOption,
+        imm3: u3,
+    ) Instruction {
+        return addSubtractExtendedRegister(0b0, 0b0, rd, rn, rm, extend, imm3); // op=0, s=0
+    }
+
+    pub fn addsExtendedRegister(
+        rd: Register,
+        rn: Register,
+        rm: Register,
+        extend: AddSubtractExtendedRegisterOption,
+        imm3: u3,
+    ) Instruction {
+        return addSubtractExtendedRegister(0b0, 0b1, rd, rn, rm, extend, imm3); // op=0, s=1 (the flag-updating add)
+    }
+
+    pub fn subExtendedRegister(
+        rd: Register,
+        rn: Register,
+        rm: Register,
+        extend: AddSubtractExtendedRegisterOption,
+        imm3: u3,
+    ) Instruction {
+        return addSubtractExtendedRegister(0b1, 0b0, rd, rn, rm, extend, imm3); // op=1, s=0
+    }
+
+    pub fn subsExtendedRegister(
+        rd: Register,
+        rn: Register,
+        rm: Register,
+        extend: AddSubtractExtendedRegisterOption,
+        imm3: u3,
+    ) Instruction {
+        return addSubtractExtendedRegister(0b1, 0b1, rd, rn, rm, extend, imm3); // op=1, s=1; Emit also uses this for cmp with a zero-register rd
+    }
+
     // Conditional branch
 
     pub fn bCond(cond: Condition, offset: i21) Instruction {
@@ -1565,11 +1657,12 @@ pub const Instruction = union(enum) {
     }
 
     pub fn smaddl(rd: Register, rn: Register, rm: Register, ra: Register) Instruction {
+        assert(rd.size() == 64 and rn.size() == 32 and rm.size() == 32 and ra.size() == 64); // rd and ra must be 64-bit aliases, rn and rm 32-bit aliases
         return dataProcessing3Source(0b00, 0b001, 0b0, rd, rn, rm, ra);
     }
 
     pub fn umaddl(rd: Register, rn: Register, rm: Register, ra: Register) Instruction {
-        assert(rd.size() == 64);
+        assert(rd.size() == 64 and rn.size() == 32 and rm.size() == 32 and ra.size() == 64); // now checks all four operands, matching smaddl
         return dataProcessing3Source(0b00, 0b101, 0b0, rd, rn, rm, ra);
     }
 
@@ -1837,6 +1930,10 @@ test "serialize instructions" {
             .inst = Instruction.smulh(.x0, .x1, .x2),
             .expected = 0b1_00_11011_0_10_00010_0_11111_00001_00000,
         },
+        .{ // adds x0, x1, x2, sxtx
+            .inst = Instruction.addsExtendedRegister(.x0, .x1, .x2, .sxtx, 0),
+            .expected = 0b1_0_1_01011_00_1_00010_111_000_00001_00000,
+        },
     };
 
     for (testcases) |case| {
src/arch/aarch64/CodeGen.zig
@@ -1294,29 +1294,23 @@ fn binOpRegister(
     };
     defer self.register_manager.unfreezeRegs(&.{rhs_reg});
 
-    const dest_reg: Register = reg: {
-        const dest_reg = switch (mir_tag) {
-            .cmp_shifted_register => undefined, // cmp has no destination register
-            else => if (maybe_inst) |inst| blk: {
-                const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-
-                if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) {
-                    break :blk lhs_reg;
-                } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) {
-                    break :blk rhs_reg;
-                } else {
-                    const raw_reg = try self.register_manager.allocReg(inst);
-                    break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
-                }
-            } else blk: {
-                const raw_reg = try self.register_manager.allocReg(null);
+    const dest_reg = switch (mir_tag) {
+        .cmp_shifted_register => undefined, // cmp has no destination register
+        else => if (maybe_inst) |inst| blk: {
+            const bin_op = self.air.instructions.items(.data)[inst].bin_op;
+
+            if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) {
+                break :blk lhs_reg;
+            } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) {
+                break :blk rhs_reg;
+            } else {
+                const raw_reg = try self.register_manager.allocReg(inst);
                 break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
-            },
-        };
-        break :reg switch (mir_tag) {
-            .smull, .umull => dest_reg.to64(),
-            else => dest_reg,
-        };
+            }
+        } else blk: {
+            const raw_reg = try self.register_manager.allocReg(null);
+            break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
+        },
     };
 
     if (!lhs_is_register) try self.genSetReg(lhs_ty, lhs_reg, lhs);
@@ -1341,9 +1335,7 @@ fn binOpRegister(
             .shift = .lsl,
         } },
         .mul,
-        .smulh,
         .smull,
-        .umulh,
         .umull,
         .lsl_register,
         .asr_register,
@@ -1932,16 +1924,38 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
                     self.register_manager.freezeRegs(&.{truncated_reg});
                     defer self.register_manager.unfreezeRegs(&.{truncated_reg});
 
-                    try self.truncRegister(dest_reg, truncated_reg, int_info.signedness, int_info.bits);
-                    _ = try self.binOp(
-                        .cmp_eq,
-                        null,
-                        dest,
-                        .{ .register = truncated_reg },
-                        Type.usize,
-                        Type.usize,
+                    try self.truncRegister(
+                        dest_reg.to32(),
+                        truncated_reg.to32(),
+                        int_info.signedness,
+                        int_info.bits,
                     );
 
+                    switch (int_info.signedness) {
+                        .signed => {
+                            _ = try self.addInst(.{
+                                .tag = .cmp_extended_register,
+                                .data = .{ .rr_extend_shift = .{
+                                    .rn = dest_reg.to64(),
+                                    .rm = truncated_reg.to32(),
+                                    .ext_type = .sxtw,
+                                    .imm3 = 0,
+                                } },
+                            });
+                        },
+                        .unsigned => {
+                            _ = try self.addInst(.{
+                                .tag = .cmp_extended_register,
+                                .data = .{ .rr_extend_shift = .{
+                                    .rn = dest_reg.to64(),
+                                    .rm = truncated_reg.to32(),
+                                    .ext_type = .uxtw,
+                                    .imm3 = 0,
+                                } },
+                            });
+                        },
+                    }
+
                     try self.genSetStack(lhs_ty, stack_offset, .{ .register = truncated_reg });
                     try self.genSetStack(Type.initTag(.u1), stack_offset - overflow_bit_offset, .{
                         .compare_flags_unsigned = .neq,
src/arch/aarch64/Emit.zig
@@ -114,6 +114,12 @@ pub fn emitMir(
             .sub_shifted_register => try emit.mirAddSubtractShiftedRegister(inst),
             .subs_shifted_register => try emit.mirAddSubtractShiftedRegister(inst),
 
+            .add_extended_register => try emit.mirAddSubtractExtendedRegister(inst),
+            .adds_extended_register => try emit.mirAddSubtractExtendedRegister(inst),
+            .sub_extended_register => try emit.mirAddSubtractExtendedRegister(inst),
+            .subs_extended_register => try emit.mirAddSubtractExtendedRegister(inst),
+            .cmp_extended_register => try emit.mirAddSubtractExtendedRegister(inst),
+
             .cset => try emit.mirConditionalSelect(inst),
 
             .dbg_line => try emit.mirDbgLine(inst),
@@ -732,6 +738,47 @@ fn mirAddSubtractShiftedRegister(emit: *Emit, inst: Mir.Inst.Index) !void {
     }
 }
 
+fn mirAddSubtractExtendedRegister(emit: *Emit, inst: Mir.Inst.Index) !void { // writes one add/adds/sub/subs/cmp (extended register) instruction for `inst`
+    const tag = emit.mir.instructions.items(.tag)[inst];
+    switch (tag) {
+        .add_extended_register,
+        .adds_extended_register,
+        .sub_extended_register,
+        .subs_extended_register,
+        => { // three-register forms read the rrr_extend_shift payload
+            const rrr_extend_shift = emit.mir.instructions.items(.data)[inst].rrr_extend_shift;
+            const rd = rrr_extend_shift.rd;
+            const rn = rrr_extend_shift.rn;
+            const rm = rrr_extend_shift.rm;
+            const ext_type = rrr_extend_shift.ext_type;
+            const imm3 = rrr_extend_shift.imm3;
+
+            switch (tag) {
+                .add_extended_register => try emit.writeInstruction(Instruction.addExtendedRegister(rd, rn, rm, ext_type, imm3)),
+                .adds_extended_register => try emit.writeInstruction(Instruction.addsExtendedRegister(rd, rn, rm, ext_type, imm3)),
+                .sub_extended_register => try emit.writeInstruction(Instruction.subExtendedRegister(rd, rn, rm, ext_type, imm3)),
+                .subs_extended_register => try emit.writeInstruction(Instruction.subsExtendedRegister(rd, rn, rm, ext_type, imm3)),
+                else => unreachable, // the enclosing arm only admits the four tags above
+            }
+        },
+        .cmp_extended_register => { // cmp carries no destination, so it reads the two-register rr_extend_shift payload
+            const rr_extend_shift = emit.mir.instructions.items(.data)[inst].rr_extend_shift;
+            const rn = rr_extend_shift.rn;
+            const rm = rr_extend_shift.rm;
+            const ext_type = rr_extend_shift.ext_type;
+            const imm3 = rr_extend_shift.imm3;
+            const zr: Register = switch (rn.size()) { // pick the zero register with the same width as rn
+                32 => .wzr,
+                64 => .xzr,
+                else => unreachable,
+            };
+
+            try emit.writeInstruction(Instruction.subsExtendedRegister(zr, rn, rm, ext_type, imm3)); // cmp is emitted as subs with the zero register as rd
+        },
+        else => unreachable, // any other tag reaching this function is a dispatch bug
+    }
+}
+
 fn mirConditionalSelect(emit: *Emit, inst: Mir.Inst.Index) !void {
     const tag = emit.mir.instructions.items(.tag)[inst];
     switch (tag) {
@@ -1013,10 +1060,10 @@ fn mirDataProcessing3Source(emit: *Emit, inst: Mir.Inst.Index) !void {
 
     switch (tag) {
         .mul => try emit.writeInstruction(Instruction.mul(rrr.rd, rrr.rn, rrr.rm)),
-        .smulh => try emit.writeInstruction(Instruction.smulh(rrr.rd, rrr.rn, rrr.rm)),
-        .smull => try emit.writeInstruction(Instruction.smull(rrr.rd, rrr.rn, rrr.rm)),
-        .umulh => try emit.writeInstruction(Instruction.umulh(rrr.rd, rrr.rn, rrr.rm)),
-        .umull => try emit.writeInstruction(Instruction.umull(rrr.rd, rrr.rn, rrr.rm)),
+        .smulh => try emit.writeInstruction(Instruction.smulh(rrr.rd.to64(), rrr.rn.to64(), rrr.rm.to64())),
+        .smull => try emit.writeInstruction(Instruction.smull(rrr.rd.to64(), rrr.rn.to32(), rrr.rm.to32())),
+        .umulh => try emit.writeInstruction(Instruction.umulh(rrr.rd.to64(), rrr.rn.to64(), rrr.rm.to64())),
+        .umull => try emit.writeInstruction(Instruction.umull(rrr.rd.to64(), rrr.rn.to32(), rrr.rm.to32())),
         else => unreachable,
     }
 }
src/arch/aarch64/Mir.zig
@@ -32,6 +32,10 @@ pub const Inst = struct {
         add_shifted_register,
         /// Add, update condition flags (shifted register)
         adds_shifted_register,
+        /// Add (extended register)
+        add_extended_register,
+        /// Add, update condition flags (extended register)
+        adds_extended_register,
         /// Bitwise AND (shifted register)
         and_shifted_register,
         /// Arithmetic Shift Right (immediate)
@@ -56,6 +60,8 @@ pub const Inst = struct {
         cmp_immediate,
         /// Compare (shifted register)
         cmp_shifted_register,
+        /// Compare (extended register)
+        cmp_extended_register,
         /// Conditional set
         cset,
         /// Pseudo-instruction: End of prologue
@@ -184,6 +190,10 @@ pub const Inst = struct {
         sub_shifted_register,
         /// Subtract, update condition flags (shifted register)
         subs_shifted_register,
+        /// Subtract (extended register)
+        sub_extended_register,
+        /// Subtract, update condition flags (extended register)
+        subs_extended_register,
         /// Supervisor Call
         svc,
         /// Test bits (immediate)
@@ -300,6 +310,15 @@ pub const Inst = struct {
             imm6: u6,
             shift: bits.Instruction.AddSubtractShiftedRegisterShift,
         },
+        /// Two registers with extension (unsigned or signed extension type and 3-bit shift amount)
+        ///
+        /// Used by e.g. cmp_extended_register
+        rr_extend_shift: struct {
+            rn: Register,
+            rm: Register,
+            ext_type: bits.Instruction.AddSubtractExtendedRegisterOption,
+            imm3: u3,
+        },
         /// Two registers and a shift (logical instruction version)
         /// (shift type and 6-bit amount)
         ///
@@ -356,6 +375,16 @@ pub const Inst = struct {
             imm6: u6,
             shift: bits.Instruction.AddSubtractShiftedRegisterShift,
         },
+        /// Three registers with extension (unsigned or signed extension type and 3-bit shift amount)
+        ///
+        /// Used by e.g. add_extended_register
+        rrr_extend_shift: struct {
+            rd: Register,
+            rn: Register,
+            rm: Register,
+            ext_type: bits.Instruction.AddSubtractExtendedRegisterOption,
+            imm3: u3,
+        },
         /// Three registers and a shift (logical instruction version)
         /// (shift type and 6-bit amount)
         ///