Commit f4421c01e8

Jakub Konka <kubkon@jakubkonka.com>
2022-05-04 21:20:31
aarch64: implement mul_with_overflow for ints in range 33-64 bits incl
1 parent 8715b01
Changed files (4)
src/arch/aarch64/bits.zig
@@ -1409,10 +1409,6 @@ pub const Instruction = union(enum) {
         return logicalImmediate(0b11, rd, rn, imms, immr, n);
     }
 
-    pub fn tstImmediate(rn: Register, imms: u6, immr: u6, n: u1) Instruction {
-        return andsImmediate(.xzr, rn, imms, immr, n);
-    }
-
     // Bitfield
 
     pub fn sbfm(rd: Register, rn: Register, immr: u6, imms: u6) Instruction {
@@ -1589,10 +1585,20 @@ pub const Instruction = union(enum) {
         return smaddl(rd, rn, rm, .xzr);
     }
 
+    pub fn smulh(rd: Register, rn: Register, rm: Register) Instruction {
+        assert(rd.size() == 64);
+        return dataProcessing3Source(0b00, 0b010, 0b0, rd, rn, rm, .xzr);
+    }
+
     pub fn umull(rd: Register, rn: Register, rm: Register) Instruction {
         return umaddl(rd, rn, rm, .xzr);
     }
 
+    pub fn umulh(rd: Register, rn: Register, rm: Register) Instruction {
+        assert(rd.size() == 64);
+        return dataProcessing3Source(0b00, 0b110, 0b0, rd, rn, rm, .xzr);
+    }
+
     pub fn mneg(rd: Register, rn: Register, rm: Register) Instruction {
         return msub(rd, rn, rm, .xzr);
     }
@@ -1820,9 +1826,17 @@ test "serialize instructions" {
             .expected = 0b1_00_11011_0_01_00001_0_11111_00000_00000,
         },
         .{ // tst x0, #0xffffffff00000000
-            .inst = Instruction.tstImmediate(.x0, 0b011111, 0b100000, 0b1),
+            .inst = Instruction.andsImmediate(.xzr, .x0, 0b011111, 0b100000, 0b1),
             .expected = 0b1_11_100100_1_100000_011111_00000_11111,
         },
+        .{ // umulh x0, x1, x2
+            .inst = Instruction.umulh(.x0, .x1, .x2),
+            .expected = 0b1_00_11011_1_10_00010_0_11111_00001_00000,
+        },
+        .{ // smulh x0, x1, x2
+            .inst = Instruction.smulh(.x0, .x1, .x2),
+            .expected = 0b1_00_11011_0_10_00010_0_11111_00001_00000,
+        },
     };
 
     for (testcases) |case| {
src/arch/aarch64/CodeGen.zig
@@ -1294,28 +1294,29 @@ fn binOpRegister(
     };
     defer self.register_manager.unfreezeRegs(&.{rhs_reg});
 
-    const dest_reg = switch (mir_tag) {
-        .cmp_shifted_register => undefined, // cmp has no destination register
-        .smull, .umull => blk: {
-            // TODO can we reuse anything for smull and umull?
-            const raw_reg = try self.register_manager.allocReg(null);
-            break :blk raw_reg.to64();
-        },
-        else => if (maybe_inst) |inst| blk: {
-            const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-
-            if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) {
-                break :blk lhs_reg;
-            } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) {
-                break :blk rhs_reg;
-            } else {
-                const raw_reg = try self.register_manager.allocReg(inst);
+    const dest_reg: Register = reg: {
+        const dest_reg = switch (mir_tag) {
+            .cmp_shifted_register => undefined, // cmp has no destination register
+            else => if (maybe_inst) |inst| blk: {
+                const bin_op = self.air.instructions.items(.data)[inst].bin_op;
+
+                if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) {
+                    break :blk lhs_reg;
+                } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) {
+                    break :blk rhs_reg;
+                } else {
+                    const raw_reg = try self.register_manager.allocReg(inst);
+                    break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
+                }
+            } else blk: {
+                const raw_reg = try self.register_manager.allocReg(null);
                 break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
-            }
-        } else blk: {
-            const raw_reg = try self.register_manager.allocReg(null);
-            break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
-        },
+            },
+        };
+        break :reg switch (mir_tag) {
+            .smull, .umull => dest_reg.to64(),
+            else => dest_reg,
+        };
     };
 
     if (!lhs_is_register) try self.genSetReg(lhs_ty, lhs_reg, lhs);
@@ -1340,7 +1341,9 @@ fn binOpRegister(
             .shift = .lsl,
         } },
         .mul,
+        .smulh,
         .smull,
+        .umulh,
         .umull,
         .lsl_register,
         .asr_register,
@@ -1946,8 +1949,177 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
 
                     break :result MCValue{ .stack_offset = stack_offset };
                 } else if (int_info.bits <= 64) {
-                    return self.fail("TODO implement mul_with_overflow for ints", .{});
-                } else return self.fail("TODO implmenet mul_with_overflow for integers > u64/i64", .{});
+                    const stack_offset = try self.allocMem(inst, tuple_size, tuple_align);
+
+                    try self.spillCompareFlagsIfOccupied();
+                    self.compare_flags_inst = null;
+
+                    // TODO this should really be put in a helper similar to `binOpRegister`
+                    const lhs_is_register = lhs == .register;
+                    const rhs_is_register = rhs == .register;
+
+                    if (lhs_is_register) self.register_manager.freezeRegs(&.{lhs.register});
+                    if (rhs_is_register) self.register_manager.freezeRegs(&.{rhs.register});
+
+                    const lhs_reg = if (lhs_is_register) lhs.register else blk: {
+                        const raw_reg = try self.register_manager.allocReg(null);
+                        const reg = registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
+                        self.register_manager.freezeRegs(&.{reg});
+                        break :blk reg;
+                    };
+                    defer self.register_manager.unfreezeRegs(&.{lhs_reg});
+
+                    const rhs_reg = if (rhs_is_register) rhs.register else blk: {
+                        const raw_reg = try self.register_manager.allocReg(null);
+                        const reg = registerAlias(raw_reg, rhs_ty.abiAlignment(self.target.*));
+                        self.register_manager.freezeRegs(&.{reg});
+                        break :blk reg;
+                    };
+                    defer self.register_manager.unfreezeRegs(&.{rhs_reg});
+
+                    if (!lhs_is_register) try self.genSetReg(lhs_ty, lhs_reg, lhs);
+                    if (!rhs_is_register) try self.genSetReg(rhs_ty, rhs_reg, rhs);
+
+                    // TODO reuse operands
+                    const dest_reg = blk: {
+                        const raw_reg = try self.register_manager.allocReg(null);
+                        const reg = registerAlias(raw_reg, lhs_ty.abiSize(self.target.*));
+                        self.register_manager.freezeRegs(&.{reg});
+                        break :blk reg;
+                    };
+                    defer self.register_manager.unfreezeRegs(&.{dest_reg});
+
+                    switch (int_info.signedness) {
+                        .signed => {
+                            // mul dest, lhs, rhs
+                            _ = try self.addInst(.{
+                                .tag = .mul,
+                                .data = .{ .rrr = .{
+                                    .rd = dest_reg,
+                                    .rn = lhs_reg,
+                                    .rm = rhs_reg,
+                                } },
+                            });
+
+                            const dest_high_reg = try self.register_manager.allocReg(null);
+                            self.register_manager.freezeRegs(&.{dest_high_reg});
+                            defer self.register_manager.unfreezeRegs(&.{dest_high_reg});
+
+                            // smulh dest_high, lhs, rhs
+                            _ = try self.addInst(.{
+                                .tag = .smulh,
+                                .data = .{ .rrr = .{
+                                    .rd = dest_high_reg,
+                                    .rn = lhs_reg,
+                                    .rm = rhs_reg,
+                                } },
+                            });
+
+                            // cmp dest_high, dest, asr #63
+                            _ = try self.addInst(.{
+                                .tag = .cmp_shifted_register,
+                                .data = .{ .rr_imm6_shift = .{
+                                    .rn = dest_high_reg,
+                                    .rm = dest_reg,
+                                    .imm6 = 63,
+                                    .shift = .asr,
+                                } },
+                            });
+
+                            const shift: u6 = @intCast(u6, @as(u7, 64) - @intCast(u7, int_info.bits));
+                            if (shift > 0) {
+                                // lsl dest_high, dest, #shift
+                                _ = try self.addInst(.{
+                                    .tag = .lsl_immediate,
+                                    .data = .{ .rr_shift = .{
+                                        .rd = dest_high_reg,
+                                        .rn = dest_reg,
+                                        .shift = shift,
+                                    } },
+                                });
+
+                                // cmp dest, dest_high, #shift
+                                _ = try self.addInst(.{
+                                    .tag = .cmp_shifted_register,
+                                    .data = .{ .rr_imm6_shift = .{
+                                        .rn = dest_reg,
+                                        .rm = dest_high_reg,
+                                        .imm6 = shift,
+                                        .shift = .asr,
+                                    } },
+                                });
+                            }
+                        },
+                        .unsigned => {
+                            const dest_high_reg = try self.register_manager.allocReg(null);
+                            self.register_manager.freezeRegs(&.{dest_high_reg});
+                            defer self.register_manager.unfreezeRegs(&.{dest_high_reg});
+
+                            // umulh dest_high, lhs, rhs
+                            _ = try self.addInst(.{
+                                .tag = .umulh,
+                                .data = .{ .rrr = .{
+                                    .rd = dest_high_reg,
+                                    .rn = lhs_reg,
+                                    .rm = rhs_reg,
+                                } },
+                            });
+
+                            // mul dest, lhs, rhs
+                            _ = try self.addInst(.{
+                                .tag = .mul,
+                                .data = .{ .rrr = .{
+                                    .rd = dest_reg,
+                                    .rn = lhs_reg,
+                                    .rm = rhs_reg,
+                                } },
+                            });
+
+                            _ = try self.binOp(
+                                .cmp_eq,
+                                null,
+                                .{ .register = dest_high_reg },
+                                .{ .immediate = 0 },
+                                Type.usize,
+                                Type.usize,
+                            );
+
+                            if (int_info.bits < 64) {
+                                // lsr dest_high, dest, #shift
+                                _ = try self.addInst(.{
+                                    .tag = .lsr_immediate,
+                                    .data = .{ .rr_shift = .{
+                                        .rd = dest_high_reg,
+                                        .rn = dest_reg,
+                                        .shift = @intCast(u6, int_info.bits),
+                                    } },
+                                });
+
+                                _ = try self.binOp(
+                                    .cmp_eq,
+                                    null,
+                                    .{ .register = dest_high_reg },
+                                    .{ .immediate = 0 },
+                                    Type.usize,
+                                    Type.usize,
+                                );
+                            }
+                        },
+                    }
+
+                    const truncated_reg = try self.register_manager.allocReg(null);
+                    self.register_manager.freezeRegs(&.{truncated_reg});
+                    defer self.register_manager.unfreezeRegs(&.{truncated_reg});
+
+                    try self.truncRegister(dest_reg, truncated_reg, int_info.signedness, int_info.bits);
+
+                    try self.genSetStack(lhs_ty, stack_offset, .{ .register = truncated_reg });
+                    try self.genSetStack(Type.initTag(.u1), stack_offset - overflow_bit_offset, .{
+                        .compare_flags_unsigned = .neq,
+                    });
+
+                    break :result MCValue{ .stack_offset = stack_offset };
+                } else return self.fail("TODO implement mul_with_overflow for integers > u64/i64", .{});
             },
             else => unreachable,
         }
src/arch/aarch64/Emit.zig
@@ -167,7 +167,9 @@ pub fn emitMir(
             .movz => try emit.mirMoveWideImmediate(inst),
 
             .mul => try emit.mirDataProcessing3Source(inst),
+            .smulh => try emit.mirDataProcessing3Source(inst),
             .smull => try emit.mirDataProcessing3Source(inst),
+            .umulh => try emit.mirDataProcessing3Source(inst),
             .umull => try emit.mirDataProcessing3Source(inst),
 
             .nop => try emit.mirNop(),
@@ -677,7 +679,14 @@ fn mirLogicalImmediate(emit: *Emit, inst: Mir.Inst.Index) !void {
 
     switch (tag) {
         .eor_immediate => try emit.writeInstruction(Instruction.eorImmediate(rd, rn, imms, immr, n)),
-        .tst_immediate => try emit.writeInstruction(Instruction.tstImmediate(rn, imms, immr, n)),
+        .tst_immediate => {
+            const zr: Register = switch (rd.size()) {
+                32 => .wzr,
+                64 => .xzr,
+                else => unreachable,
+            };
+            try emit.writeInstruction(Instruction.andsImmediate(zr, rn, imms, immr, n));
+        },
         else => unreachable,
     }
 }
@@ -1004,7 +1013,9 @@ fn mirDataProcessing3Source(emit: *Emit, inst: Mir.Inst.Index) !void {
 
     switch (tag) {
         .mul => try emit.writeInstruction(Instruction.mul(rrr.rd, rrr.rn, rrr.rm)),
+        .smulh => try emit.writeInstruction(Instruction.smulh(rrr.rd, rrr.rn, rrr.rm)),
         .smull => try emit.writeInstruction(Instruction.smull(rrr.rd, rrr.rn, rrr.rm)),
+        .umulh => try emit.writeInstruction(Instruction.umulh(rrr.rd, rrr.rn, rrr.rm)),
         .umull => try emit.writeInstruction(Instruction.umull(rrr.rd, rrr.rn, rrr.rm)),
         else => unreachable,
     }
src/arch/aarch64/Mir.zig
@@ -146,6 +146,8 @@ pub const Inst = struct {
         ret,
         /// Signed bitfield extract
         sbfx,
+        /// Signed multiply high
+        smulh,
         /// Signed multiply long
         smull,
         /// Signed extend byte
@@ -188,6 +190,8 @@ pub const Inst = struct {
         tst_immediate,
         /// Unsigned bitfield extract
         ubfx,
+        /// Unsigned multiply high
+        umulh,
         /// Unsigned multiply long
         umull,
         /// Unsigned extend byte