Commit 5e010b6dea

David Rubin <daviru007@icloud.com>
2024-03-23 04:14:10
riscv: reorganize `binOp` and implement `cmp_imm_gte` MIR
this was an annoying one to do, as there is no (to my knowledge) myriad sequence that will allow us to do `gte` compares with an immediate without allocating a register. RISC-V provides a single instruction to do compares, that being `lt`, and so you need to use more than one for other variants, but in this case, i believe you need to allocate a register.
1 parent 63bbf66
Changed files (3)
src/arch/riscv64/CodeGen.zig
@@ -850,6 +850,122 @@ fn airSlice(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
 }
 
+fn airBinOp(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
+    const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
+    const lhs = try self.resolveInst(bin_op.lhs);
+    const rhs = try self.resolveInst(bin_op.rhs);
+    const lhs_ty = self.typeOf(bin_op.lhs);
+    const rhs_ty = self.typeOf(bin_op.rhs);
+
+    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else try self.binOp(tag, inst, lhs, rhs, lhs_ty, rhs_ty);
+    return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
+}
+
+/// For all your binary operation needs, this function will generate
+/// the corresponding Mir instruction(s). Returns the location of the
+/// result.
+///
+/// If the binary operation itself happens to be an Air instruction,
+/// pass the corresponding index in the inst parameter. That helps
+/// this function do stuff like reusing operands.
+///
+/// This function does not do any lowering to Mir itself, but instead
+/// looks at the lhs and rhs and determines which kind of lowering
+/// would be best suitable and then delegates the lowering to other
+/// functions.
+///
+/// `maybe_inst` **needs** to be a bin_op, make sure of that.
+fn binOp(
+    self: *Self,
+    tag: Air.Inst.Tag,
+    maybe_inst: ?Air.Inst.Index,
+    lhs: MCValue,
+    rhs: MCValue,
+    lhs_ty: Type,
+    rhs_ty: Type,
+) InnerError!MCValue {
+    const mod = self.bin_file.comp.module.?;
+    switch (tag) {
+        // Arithmetic operations on integers and floats
+        .add,
+        .sub,
+        .cmp_eq,
+        .cmp_neq,
+        .cmp_gt,
+        .cmp_gte,
+        .cmp_lt,
+        .cmp_lte,
+        => {
+            switch (lhs_ty.zigTypeTag(mod)) {
+                .Float => return self.fail("TODO binary operations on floats", .{}),
+                .Vector => return self.fail("TODO binary operations on vectors", .{}),
+                .Int => {
+                    assert(lhs_ty.eql(rhs_ty, mod));
+                    const int_info = lhs_ty.intInfo(mod);
+                    if (int_info.bits <= 64) {
+                        if (rhs == .immediate) {
+                            return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
+                        }
+                        return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
+                    } else {
+                        return self.fail("TODO binary operations on int with bits > 64", .{});
+                    }
+                },
+                else => unreachable,
+            }
+        },
+        .ptr_add,
+        .ptr_sub,
+        => {
+            switch (lhs_ty.zigTypeTag(mod)) {
+                .Pointer => {
+                    const ptr_ty = lhs_ty;
+                    const elem_ty = switch (ptr_ty.ptrSize(mod)) {
+                        .One => ptr_ty.childType(mod).childType(mod), // ptr to array, so get array element type
+                        else => ptr_ty.childType(mod),
+                    };
+                    const elem_size = elem_ty.abiSize(mod);
+
+                    if (elem_size == 1) {
+                        const base_tag: Air.Inst.Tag = switch (tag) {
+                            .ptr_add => .add,
+                            .ptr_sub => .sub,
+                            else => unreachable,
+                        };
+
+                        return try self.binOpRegister(base_tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
+                    } else {
+                        return self.fail("TODO ptr_add with elem_size > 1", .{});
+                    }
+                },
+                else => unreachable,
+            }
+        },
+
+        // These instructions have unsymteric bit sizes on RHS and LHS.
+        .shr,
+        .shl,
+        => {
+            switch (lhs_ty.zigTypeTag(mod)) {
+                .Float => return self.fail("TODO binary operations on floats", .{}),
+                .Vector => return self.fail("TODO binary operations on vectors", .{}),
+                .Int => {
+                    const int_info = lhs_ty.intInfo(mod);
+                    if (int_info.bits <= 64) {
+                        if (rhs == .immediate) {
+                            return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
+                        }
+                        return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
+                    } else {
+                        return self.fail("TODO binary operations on int with bits > 64", .{});
+                    }
+                },
+                else => unreachable,
+            }
+        },
+        else => unreachable,
+    }
+}
 /// Don't call this function directly. Use binOp instead.
 ///
 /// Calling this function signals an intention to generate a Mir
@@ -963,7 +1079,6 @@ fn binOpImm(
     lhs_ty: Type,
     rhs_ty: Type,
 ) !MCValue {
-    _ = rhs_ty;
     assert(rhs == .immediate);
 
     const lhs_is_register = lhs == .register;
@@ -1006,142 +1121,44 @@ fn binOpImm(
     const mir_tag: Mir.Inst.Tag = switch (tag) {
         .shl => .slli,
         .shr => .srli,
+        .cmp_gte => .cmp_imm_gte,
         else => return self.fail("TODO: binOpImm {s}", .{@tagName(tag)}),
     };
 
-    _ = try self.addInst(.{
-        .tag = mir_tag,
-        .data = .{
-            .i_type = .{
-                .rd = dest_reg,
-                .rs1 = lhs_reg,
-                .imm12 = math.cast(i12, rhs.immediate) orelse {
-                    return self.fail("TODO: binOpImm larger than i12 i_type payload", .{});
-                },
-            },
-        },
-    });
-
-    // generate the struct for OF checks
-
-    return MCValue{ .register = dest_reg };
-}
-
-/// For all your binary operation needs, this function will generate
-/// the corresponding Mir instruction(s). Returns the location of the
-/// result.
-///
-/// If the binary operation itself happens to be an Air instruction,
-/// pass the corresponding index in the inst parameter. That helps
-/// this function do stuff like reusing operands.
-///
-/// This function does not do any lowering to Mir itself, but instead
-/// looks at the lhs and rhs and determines which kind of lowering
-/// would be best suitable and then delegates the lowering to other
-/// functions.
-///
-/// `maybe_inst` **needs** to be a bin_op, make sure of that.
-fn binOp(
-    self: *Self,
-    tag: Air.Inst.Tag,
-    maybe_inst: ?Air.Inst.Index,
-    lhs: MCValue,
-    rhs: MCValue,
-    lhs_ty: Type,
-    rhs_ty: Type,
-) InnerError!MCValue {
-    const mod = self.bin_file.comp.module.?;
-    switch (tag) {
-        // Arithmetic operations on integers and floats
-        .add,
-        .sub,
-        .cmp_eq,
-        .cmp_neq,
-        .cmp_gt,
-        .cmp_gte,
-        .cmp_lt,
-        .cmp_lte,
+    // apply some special operations needed
+    switch (mir_tag) {
+        .slli,
+        .srli,
         => {
-            switch (lhs_ty.zigTypeTag(mod)) {
-                .Float => return self.fail("TODO binary operations on floats", .{}),
-                .Vector => return self.fail("TODO binary operations on vectors", .{}),
-                .Int => {
-                    assert(lhs_ty.eql(rhs_ty, mod));
-                    const int_info = lhs_ty.intInfo(mod);
-                    if (int_info.bits <= 64) {
-                        if (rhs == .immediate) {
-                            return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
-                        }
-                        return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
-                    } else {
-                        return self.fail("TODO binary operations on int with bits > 64", .{});
-                    }
-                },
-                else => unreachable,
-            }
-        },
-        .ptr_add,
-        .ptr_sub,
-        => {
-            switch (lhs_ty.zigTypeTag(mod)) {
-                .Pointer => {
-                    const ptr_ty = lhs_ty;
-                    const elem_ty = switch (ptr_ty.ptrSize(mod)) {
-                        .One => ptr_ty.childType(mod).childType(mod), // ptr to array, so get array element type
-                        else => ptr_ty.childType(mod),
-                    };
-                    const elem_size = elem_ty.abiSize(mod);
-
-                    if (elem_size == 1) {
-                        const base_tag: Air.Inst.Tag = switch (tag) {
-                            .ptr_add => .add,
-                            .ptr_sub => .sub,
-                            else => unreachable,
-                        };
-
-                        return try self.binOpRegister(base_tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
-                    } else {
-                        return self.fail("TODO ptr_add with elem_size > 1", .{});
-                    }
-                },
-                else => unreachable,
-            }
+            _ = try self.addInst(.{
+                .tag = mir_tag,
+                .data = .{ .i_type = .{
+                    .rd = dest_reg,
+                    .rs1 = lhs_reg,
+                    .imm12 = math.cast(i12, rhs.immediate) orelse {
+                        return self.fail("TODO: binOpImm larger than i12 i_type payload", .{});
+                    },
+                } },
+            });
         },
+        .cmp_imm_gte => {
+            const imm_reg = try self.copyToTmpRegister(rhs_ty, .{ .immediate = rhs.immediate - 1 });
 
-        // These instructions have unsymteric bit sizes.
-        .shr,
-        .shl,
-        => {
-            switch (lhs_ty.zigTypeTag(mod)) {
-                .Float => return self.fail("TODO binary operations on floats", .{}),
-                .Vector => return self.fail("TODO binary operations on vectors", .{}),
-                .Int => {
-                    const int_info = lhs_ty.intInfo(mod);
-                    if (int_info.bits <= 64) {
-                        if (rhs == .immediate) {
-                            return self.binOpImm(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
-                        }
-                        return self.binOpRegister(tag, maybe_inst, lhs, rhs, lhs_ty, rhs_ty);
-                    } else {
-                        return self.fail("TODO binary operations on int with bits > 64", .{});
-                    }
-                },
-                else => unreachable,
-            }
+            _ = try self.addInst(.{
+                .tag = mir_tag,
+                .data = .{ .r_type = .{
+                    .rd = dest_reg,
+                    .rs1 = imm_reg,
+                    .rs2 = lhs_reg,
+                } },
+            });
         },
         else => unreachable,
     }
-}
 
-fn airBinOp(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
-    const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
-    const lhs = try self.resolveInst(bin_op.lhs);
-    const rhs = try self.resolveInst(bin_op.rhs);
-    const lhs_ty = self.typeOf(bin_op.lhs);
-    const rhs_ty = self.typeOf(bin_op.rhs);
+    // generate the struct for overflow checks
 
-    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else try self.binOp(tag, inst, lhs, rhs, lhs_ty, rhs_ty);
-    return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
+    return MCValue{ .register = dest_reg };
 }
 
 fn airPtrArithmetic(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
@@ -2101,8 +2118,12 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
     const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra.end + then_body.len ..][0..extra.data.else_body_len]);
     const liveness_condbr = self.liveness.getCondBr(inst);
 
-    // A branch to the false section. Uses beq
-    const reloc = try self.condBr(cond_ty, cond);
+    const cond_reg = try self.register_manager.allocReg(inst, gp);
+    const cond_reg_lock = self.register_manager.lockRegAssumeUnused(cond_reg);
+    defer self.register_manager.unlockReg(cond_reg_lock);
+
+    // A branch to the false section. Uses bne
+    const reloc = try self.condBr(cond_ty, cond, cond_reg);
 
     // If the condition dies here in this condbr instruction, process
     // that death now instead of later as this has an effect on
@@ -2233,19 +2254,14 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
     }
 }
 
-fn condBr(self: *Self, cond_ty: Type, condition: MCValue) !Mir.Inst.Index {
-    _ = cond_ty;
-
-    const reg = switch (condition) {
-        .register => |r| r,
-        else => try self.copyToTmpRegister(Type.bool, condition),
-    };
+fn condBr(self: *Self, cond_ty: Type, condition: MCValue, cond_reg: Register) !Mir.Inst.Index {
+    try self.genSetReg(cond_ty, cond_reg, condition);
 
     return try self.addInst(.{
         .tag = .bne,
         .data = .{
             .b_type = .{
-                .rs1 = reg,
+                .rs1 = cond_reg,
                 .rs2 = .zero,
                 .inst = undefined,
             },
@@ -2739,6 +2755,7 @@ fn genSetStack(self: *Self, ty: Type, stack_offset: u32, src_val: MCValue) Inner
                         } else return self.fail("TODO genSetStack for {s}", .{@tagName(self.bin_file.tag)});
                     };
 
+                    // setup the src pointer
                     _ = try self.addInst(.{
                         .tag = .load_symbol,
                         .data = .{
@@ -2789,7 +2806,7 @@ fn genInlineMemcpy(
 
     // compare count to length
     const compare_inst = try self.addInst(.{
-        .tag = .cmp_gt,
+        .tag = .cmp_eq,
         .data = .{ .r_type = .{
             .rd = tmp,
             .rs1 = count,
@@ -2861,9 +2878,12 @@ fn genSetReg(self: *Self, ty: Type, reg: Register, src_val: MCValue) InnerError!
                     } },
                 });
             } else {
+                // TODO: use a more advanced myriad seq to do this without a reg.
+                // see: https://github.com/llvm/llvm-project/blob/081a66ffacfe85a37ff775addafcf3371e967328/llvm/lib/Target/RISCV/MCTargetDesc/RISCVMatInt.cpp#L224
+
                 const temp = try self.register_manager.allocReg(null, gp);
-                const maybe_temp_lock = self.register_manager.lockReg(temp);
-                defer if (maybe_temp_lock) |temp_lock| self.register_manager.unlockReg(temp_lock);
+                const temp_lock = self.register_manager.lockRegAssumeUnused(temp);
+                defer self.register_manager.unlockReg(temp_lock);
 
                 const lo32: i32 = @truncate(x);
                 const carry: i32 = if (lo32 < 0) 1 else 0;
src/arch/riscv64/Emit.zig
@@ -59,6 +59,7 @@ pub fn emitMir(
 
             .cmp_eq => try emit.mirRType(inst),
             .cmp_gt => try emit.mirRType(inst),
+            .cmp_imm_gte => try emit.mirRType(inst),
 
             .beq => try emit.mirBType(inst),
             .bne => try emit.mirBType(inst),
@@ -185,14 +186,27 @@ fn mirRType(emit: *Emit, inst: Mir.Inst.Index) !void {
     switch (tag) {
         .add => try emit.writeInstruction(Instruction.add(rd, rs1, rs2)),
         .sub => try emit.writeInstruction(Instruction.sub(rd, rs1, rs2)),
-        .cmp_gt => try emit.writeInstruction(Instruction.slt(rd, rs1, rs2)),
+        .cmp_gt => {
+            // rs1 > rs2
+            try emit.writeInstruction(Instruction.slt(rd, rs1, rs2));
+        },
         .cmp_eq => {
+            // rs1 == rs2
+
+            // if equal, write 0 to rd
             try emit.writeInstruction(Instruction.xor(rd, rs1, rs2));
+            // if rd == 0, set rd to 1
             try emit.writeInstruction(Instruction.sltiu(rd, rd, 1));
         },
         .sllw => try emit.writeInstruction(Instruction.sllw(rd, rs1, rs2)),
         .srlw => try emit.writeInstruction(Instruction.srlw(rd, rs1, rs2)),
         .@"or" => try emit.writeInstruction(Instruction.@"or"(rd, rs1, rs2)),
+        .cmp_imm_gte => {
+            // rd = rs1 >= imm12
+            // see the docstring for cmp_imm_gte to see why we use r_type here
+            try emit.writeInstruction(Instruction.slt(rd, rs1, rs2));
+            try emit.writeInstruction(Instruction.xori(rd, rd, 1));
+        },
         else => unreachable,
     }
 }
@@ -220,30 +234,34 @@ fn mirIType(emit: *Emit, inst: Mir.Inst.Index) !void {
     const tag = emit.mir.instructions.items(.tag)[inst];
     const i_type = emit.mir.instructions.items(.data)[inst].i_type;
 
+    const rd = i_type.rd;
+    const rs1 = i_type.rs1;
+    const imm12 = i_type.imm12;
+
     switch (tag) {
-        .addi => try emit.writeInstruction(Instruction.addi(i_type.rd, i_type.rs1, i_type.imm12)),
-        .jalr => try emit.writeInstruction(Instruction.jalr(i_type.rd, i_type.imm12, i_type.rs1)),
+        .addi => try emit.writeInstruction(Instruction.addi(rd, rs1, imm12)),
+        .jalr => try emit.writeInstruction(Instruction.jalr(rd, imm12, rs1)),
 
-        .ld => try emit.writeInstruction(Instruction.ld(i_type.rd, i_type.imm12, i_type.rs1)),
-        .lw => try emit.writeInstruction(Instruction.lw(i_type.rd, i_type.imm12, i_type.rs1)),
-        .lh => try emit.writeInstruction(Instruction.lh(i_type.rd, i_type.imm12, i_type.rs1)),
-        .lb => try emit.writeInstruction(Instruction.lb(i_type.rd, i_type.imm12, i_type.rs1)),
+        .ld => try emit.writeInstruction(Instruction.ld(rd, imm12, rs1)),
+        .lw => try emit.writeInstruction(Instruction.lw(rd, imm12, rs1)),
+        .lh => try emit.writeInstruction(Instruction.lh(rd, imm12, rs1)),
+        .lb => try emit.writeInstruction(Instruction.lb(rd, imm12, rs1)),
 
-        .sd => try emit.writeInstruction(Instruction.sd(i_type.rd, i_type.imm12, i_type.rs1)),
-        .sw => try emit.writeInstruction(Instruction.sw(i_type.rd, i_type.imm12, i_type.rs1)),
-        .sh => try emit.writeInstruction(Instruction.sh(i_type.rd, i_type.imm12, i_type.rs1)),
-        .sb => try emit.writeInstruction(Instruction.sb(i_type.rd, i_type.imm12, i_type.rs1)),
+        .sd => try emit.writeInstruction(Instruction.sd(rd, imm12, rs1)),
+        .sw => try emit.writeInstruction(Instruction.sw(rd, imm12, rs1)),
+        .sh => try emit.writeInstruction(Instruction.sh(rd, imm12, rs1)),
+        .sb => try emit.writeInstruction(Instruction.sb(rd, imm12, rs1)),
 
-        .ldr_ptr_stack => try emit.writeInstruction(Instruction.add(i_type.rd, i_type.rs1, .sp)),
+        .ldr_ptr_stack => try emit.writeInstruction(Instruction.add(rd, rs1, .sp)),
 
         .abs => {
-            try emit.writeInstruction(Instruction.sraiw(i_type.rd, i_type.rs1, @intCast(i_type.imm12)));
-            try emit.writeInstruction(Instruction.xor(i_type.rs1, i_type.rs1, i_type.rd));
-            try emit.writeInstruction(Instruction.subw(i_type.rs1, i_type.rs1, i_type.rd));
+            try emit.writeInstruction(Instruction.sraiw(rd, rs1, @intCast(imm12)));
+            try emit.writeInstruction(Instruction.xor(rs1, rs1, rd));
+            try emit.writeInstruction(Instruction.subw(rs1, rs1, rd));
         },
 
-        .srli => try emit.writeInstruction(Instruction.srli(i_type.rd, i_type.rs1, @intCast(i_type.imm12))),
-        .slli => try emit.writeInstruction(Instruction.slli(i_type.rd, i_type.rs1, @intCast(i_type.imm12))),
+        .srli => try emit.writeInstruction(Instruction.srli(rd, rs1, @intCast(imm12))),
+        .slli => try emit.writeInstruction(Instruction.slli(rd, rs1, @intCast(imm12))),
 
         else => unreachable,
     }
@@ -471,12 +489,13 @@ fn instructionSize(emit: *Emit, inst: Mir.Inst.Index) usize {
         .dbg_prologue_end,
         => 0,
 
-        .psuedo_epilogue => 12, // 3 * 4
-        .psuedo_prologue => 16, // 4 * 4
+        .psuedo_epilogue => 12,
+        .psuedo_prologue => 16,
 
-        .abs => 12, // 3 * 4
+        .abs => 12,
 
         .cmp_eq => 8,
+        .cmp_imm_gte => 8,
 
         else => 4,
     };
src/arch/riscv64/Mir.zig
@@ -57,12 +57,21 @@ pub const Inst = struct {
         /// Jumps. Uses `inst` payload.
         j,
 
-        // TODO: Maybe create a special data for compares that includes the ops
-        /// Compare equal, uses r_type
+        // NOTE: Maybe create a special data for compares that includes the ops
+        /// Register `==`, uses r_type
         cmp_eq,
-        /// Compare greater than, uses r_type
+        /// Register `>`, uses r_type
         cmp_gt,
 
+        /// Immediate `>=`, uses r_type
+        ///
+        /// Note: this uses r_type because RISC-V does not provide a good way
+        /// to do `>=` comparisons on immediates. Usually we would just subtract
+        /// 1 from the immediate and do a `>` comparison, however there is no `>`
+        /// register to immedate comparison in RISC-V. This leads us to need to
+        /// allocate a register for temporary use.
+        cmp_imm_gte,
+
         /// Branch if equal Uses b_type
         beq,
         /// Branch if not eql Uses b_type