Commit b28c966e33

David Rubin <daviru007@icloud.com>
2024-03-25 13:15:02
riscv: fix overflow checks in addition.
1 parent e70584e
Changed files (3)
src/arch/riscv64/CodeGen.zig
@@ -1019,6 +1019,7 @@ fn binOpRegister(
         .add => .add,
         .sub => .sub,
         .cmp_eq => .cmp_eq,
+        .cmp_neq => .cmp_neq,
         .cmp_gt => .cmp_gt,
         .cmp_gte => .cmp_gte,
         .cmp_lt => .cmp_lt,
@@ -1185,6 +1186,8 @@ fn airAddWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
         const rhs_ty = self.typeOf(extra.rhs);
 
         const add_result_mcv = try self.binOp(.add, null, lhs, rhs, lhs_ty, rhs_ty);
+        const add_result_lock = self.register_manager.lockRegAssumeUnused(add_result_mcv.register);
+        defer self.register_manager.unlockReg(add_result_lock);
 
         const tuple_ty = self.typeOfIndex(inst);
         const int_info = lhs_ty.intInfo(mod);
@@ -1196,15 +1199,44 @@ fn airAddWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
 
         const result_offset = tuple_ty.structFieldOffset(0, mod) + offset;
 
-        // set the result first as we don't have a lock on the add_result_mcv register and it will
-        // get clobbered in the next binOp.
         try self.genSetStack(lhs_ty, @intCast(result_offset), add_result_mcv);
 
         if (int_info.bits >= 8 and math.isPowerOfTwo(int_info.bits)) {
             if (int_info.signedness == .unsigned) {
                 const overflow_offset = tuple_ty.structFieldOffset(1, mod) + offset;
 
-                const overflow_mcv = try self.binOp(.cmp_lt, null, add_result_mcv, lhs, lhs_ty, lhs_ty);
+                const max_val = std.math.pow(u16, 2, int_info.bits) - 1;
+
+                const overflow_reg, const overflow_lock = try self.allocReg();
+                defer self.register_manager.unlockReg(overflow_lock);
+
+                const add_reg, const add_lock = blk: {
+                    if (add_result_mcv == .register) break :blk .{ add_result_mcv.register, null };
+
+                    const add_reg, const add_lock = try self.allocReg();
+                    try self.genSetReg(lhs_ty, add_reg, add_result_mcv);
+                    break :blk .{ add_reg, add_lock };
+                };
+                defer if (add_lock) |lock| self.register_manager.unlockReg(lock);
+
+                _ = try self.addInst(.{
+                    .tag = .andi,
+                    .data = .{ .i_type = .{
+                        .rd = overflow_reg,
+                        .rs1 = add_reg,
+                        .imm12 = @intCast(max_val),
+                    } },
+                });
+
+                const overflow_mcv = try self.binOp(
+                    .cmp_neq,
+                    null,
+                    .{ .register = overflow_reg },
+                    .{ .register = add_reg },
+                    lhs_ty,
+                    lhs_ty,
+                );
+
                 try self.genSetStack(Type.u1, @intCast(overflow_offset), overflow_mcv);
 
                 break :result result_mcv;
@@ -3042,7 +3074,15 @@ fn genSetReg(self: *Self, ty: Type, reg: Register, src_val: MCValue) InnerError!
 
 fn airIntFromPtr(self: *Self, inst: Air.Inst.Index) !void { // result is the operand's value, reused in place or copied
     const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; // the instruction's single operand
-    const result = try self.resolveInst(un_op);
+    const result = result: {
+        const src_mcv = try self.resolveInst(un_op); // machine value currently holding the operand
+        if (self.reuseOperand(inst, un_op, 0, src_mcv)) break :result src_mcv; // reuse accepted: operand value doubles as the result
+
+        const dst_mcv = try self.allocRegOrMem(inst, true); // otherwise get a fresh destination (NOTE(review): `true` presumably permits a register; confirm)
+        const dst_ty = self.typeOfIndex(inst); // type of this instruction's result
+        try self.setValue(dst_ty, dst_mcv, src_mcv); // presumably copies the operand into the destination; verify against setValue
+        break :result dst_mcv;
+    };
     return self.finishAir(inst, result, .{ un_op, .none, .none }); // hand the result and the one operand used to finishAir
 }
 
src/arch/riscv64/Emit.zig
@@ -58,6 +58,7 @@ pub fn emitMir(
             .@"or" => try emit.mirRType(inst),
 
             .cmp_eq => try emit.mirRType(inst),
+            .cmp_neq => try emit.mirRType(inst),
             .cmp_gt => try emit.mirRType(inst),
             .cmp_gte => try emit.mirRType(inst),
             .cmp_lt => try emit.mirRType(inst),
@@ -68,6 +69,7 @@ pub fn emitMir(
             .bne => try emit.mirBType(inst),
 
             .addi => try emit.mirIType(inst),
+            .andi => try emit.mirIType(inst),
             .jalr => try emit.mirIType(inst),
             .abs => try emit.mirIType(inst),
 
@@ -201,10 +203,14 @@ fn mirRType(emit: *Emit, inst: Mir.Inst.Index) !void {
         .cmp_eq => {
             // rs1 == rs2
 
-            // if equal, write 0 to rd
             try emit.writeInstruction(Instruction.xor(rd, rs1, rs2));
-            // if rd == 0, set rd to 1
-            try emit.writeInstruction(Instruction.sltiu(rd, rd, 1));
+            try emit.writeInstruction(Instruction.sltiu(rd, rd, 1)); // seqz
+        },
+        .cmp_neq => {
+            // rs1 != rs2
+
+            try emit.writeInstruction(Instruction.xor(rd, rs1, rs2));
+            try emit.writeInstruction(Instruction.sltu(rd, .x0, rd)); // snez
         },
         .cmp_lt => {
             // rd = 1 if rs1 < rs2
@@ -255,6 +261,8 @@ fn mirIType(emit: *Emit, inst: Mir.Inst.Index) !void {
         .addi => try emit.writeInstruction(Instruction.addi(rd, rs1, imm12)),
         .jalr => try emit.writeInstruction(Instruction.jalr(rd, imm12, rs1)),
 
+        .andi => try emit.writeInstruction(Instruction.andi(rd, rs1, imm12)),
+
         .ld => try emit.writeInstruction(Instruction.ld(rd, imm12, rs1)),
         .lw => try emit.writeInstruction(Instruction.lw(rd, imm12, rs1)),
         .lh => try emit.writeInstruction(Instruction.lh(rd, imm12, rs1)),
@@ -515,6 +523,7 @@ fn instructionSize(emit: *Emit, inst: Mir.Inst.Index) usize {
         => 12,
 
         .cmp_eq,
+        .cmp_neq,
         .cmp_imm_eq,
         .cmp_gte,
         .load_symbol,
src/arch/riscv64/Mir.zig
@@ -57,9 +57,14 @@ pub const Inst = struct {
         /// Jumps. Uses `inst` payload.
         j,
 
+        /// Bitwise AND with an immediate, uses i_type payload
+        andi,
+
         // NOTE: Maybe create a special data for compares that includes the ops
         /// Register `==`, uses r_type
         cmp_eq,
+        /// Register `!=`, uses r_type
+        cmp_neq,
         /// Register `>`, uses r_type
         cmp_gt,
         /// Register `<`, uses r_type