Commit 3c0015c828

David Rubin <daviru007@icloud.com>
2024-03-25 18:33:34
riscv: implement a basic `@intCast`
the truncation panic logic is generated in Sema, so I don't need to roll anything of my own. I add all of the boilerplate for that detecting the truncation and it works in basic test cases!
1 parent 685f828
Changed files (4)
src/arch/riscv64/CodeGen.zig
@@ -797,23 +797,57 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) !void {
 }
 
 fn airIntCast(self: *Self, inst: Air.Inst.Index) !void {
+    const mod = self.bin_file.comp.module.?;
     const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
-    if (self.liveness.isUnused(inst))
-        return self.finishAir(inst, .dead, .{ ty_op.operand, .none, .none });
+    const src_ty = self.typeOf(ty_op.operand);
+    const dst_ty = self.typeOfIndex(inst);
 
-    const mod = self.bin_file.comp.module.?;
-    const operand_ty = self.typeOf(ty_op.operand);
-    const operand = try self.resolveInst(ty_op.operand);
-    const info_a = operand_ty.intInfo(mod);
-    const info_b = self.typeOfIndex(inst).intInfo(mod);
-    if (info_a.signedness != info_b.signedness)
-        return self.fail("TODO gen intcast sign safety in semantic analysis", .{});
+    const result: MCValue = result: {
+        const dst_abi_size: u32 = @intCast(dst_ty.abiSize(mod));
 
-    if (info_a.bits == info_b.bits)
-        return self.finishAir(inst, operand, .{ ty_op.operand, .none, .none });
+        const src_int_info = src_ty.intInfo(mod);
+        const dst_int_info = dst_ty.intInfo(mod);
+        const extend = switch (src_int_info.signedness) {
+            .signed => dst_int_info,
+            .unsigned => src_int_info,
+        }.signedness;
 
-    return self.fail("TODO implement intCast for {}", .{self.target.cpu.arch});
-    // return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+        _ = dst_abi_size;
+        _ = extend;
+
+        const min_ty = if (dst_int_info.bits < src_int_info.bits) dst_ty else src_ty;
+
+        const src_mcv = try self.resolveInst(ty_op.operand);
+
+        const src_storage_bits: u16 = switch (src_mcv) {
+            .register => 64,
+            .stack_offset => src_int_info.bits,
+            else => return self.fail("airIntCast from {s}", .{@tagName(src_mcv)}),
+        };
+
+        const dst_mcv = if (dst_int_info.bits <= src_storage_bits and
+            math.divCeil(u16, dst_int_info.bits, 64) catch unreachable ==
+            math.divCeil(u32, src_storage_bits, 64) catch unreachable and
+            self.reuseOperand(inst, ty_op.operand, 0, src_mcv)) src_mcv else dst: {
+            const dst_mcv = try self.allocRegOrMem(inst, true);
+            try self.setValue(min_ty, dst_mcv, src_mcv);
+            break :dst dst_mcv;
+        };
+
+        if (dst_int_info.bits <= src_int_info.bits) {
+            break :result dst_mcv;
+        }
+
+        if (dst_int_info.bits > 64 or src_int_info.bits > 64) {
+            break :result null; // TODO
+        }
+
+        break :result dst_mcv;
+    } orelse return self.fail("TODO implement airIntCast from {} to {}", .{
+        src_ty.fmt(mod), dst_ty.fmt(mod),
+    });
+
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
 fn airTrunc(self: *Self, inst: Air.Inst.Index) !void {
@@ -1080,7 +1114,9 @@ fn binOpImm(
         .shr => .srli,
         .cmp_gte => .cmp_imm_gte,
         .cmp_eq => .cmp_imm_eq,
+        .cmp_lte => .cmp_imm_lte,
         .add => .addi,
+        .sub => .addiw,
         else => return self.fail("TODO: binOpImm {s}", .{@tagName(tag)}),
     };
 
@@ -1090,6 +1126,7 @@ fn binOpImm(
         .srli,
         .addi,
         .cmp_imm_eq,
+        .cmp_imm_lte,
         => {
             _ = try self.addInst(.{
                 .tag = mir_tag,
@@ -1102,6 +1139,18 @@ fn binOpImm(
                 } },
             });
         },
+        .addiw => {
+            _ = try self.addInst(.{
+                .tag = mir_tag,
+                .data = .{ .i_type = .{
+                    .rd = dest_reg,
+                    .rs1 = lhs_reg,
+                    .imm12 = -(math.cast(i12, rhs.immediate) orelse {
+                        return self.fail("TODO: binOpImm larger than i12 i_type payload", .{});
+                    }),
+                } },
+            });
+        },
         .cmp_imm_gte => {
             const imm_reg = try self.copyToTmpRegister(rhs_ty, .{ .immediate = rhs.immediate - 1 });
 
@@ -1146,7 +1195,16 @@ fn airAddSat(self: *Self, inst: Air.Inst.Index) !void {
 
 fn airSubWrap(self: *Self, inst: Air.Inst.Index) !void {
     const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
-    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement subwrap for {}", .{self.target.cpu.arch});
+    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
+        // RISCV arthemtic instructions already wrap, so this is simply a sub binOp with
+        // no overflow checks.
+        const lhs = try self.resolveInst(bin_op.lhs);
+        const rhs = try self.resolveInst(bin_op.rhs);
+        const lhs_ty = self.typeOf(bin_op.lhs);
+        const rhs_ty = self.typeOf(bin_op.rhs);
+
+        break :result try self.binOp(.sub, inst, lhs, rhs, lhs_ty, rhs_ty);
+    };
     return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
 }
 
@@ -3441,3 +3499,7 @@ fn typeOfIndex(self: *Self, inst: Air.Inst.Index) Type {
     const mod = self.bin_file.comp.module.?;
     return self.air.typeOfIndex(inst, &mod.intern_pool);
 }
+
+fn hasFeature(self: *Self, feature: Target.riscv.Feature) bool {
+    return Target.riscv.featureSetHas(self.target.cpu.features, feature);
+}
src/arch/riscv64/Emit.zig
@@ -64,11 +64,13 @@ pub fn emitMir(
             .cmp_lt => try emit.mirRType(inst),
             .cmp_imm_gte => try emit.mirRType(inst),
             .cmp_imm_eq => try emit.mirIType(inst),
+            .cmp_imm_lte => try emit.mirIType(inst),
 
             .beq => try emit.mirBType(inst),
             .bne => try emit.mirBType(inst),
 
             .addi => try emit.mirIType(inst),
+            .addiw => try emit.mirIType(inst),
             .andi => try emit.mirIType(inst),
             .jalr => try emit.mirIType(inst),
             .abs => try emit.mirIType(inst),
@@ -259,6 +261,7 @@ fn mirIType(emit: *Emit, inst: Mir.Inst.Index) !void {
 
     switch (tag) {
         .addi => try emit.writeInstruction(Instruction.addi(rd, rs1, imm12)),
+        .addiw => try emit.writeInstruction(Instruction.addiw(rd, rs1, imm12)),
         .jalr => try emit.writeInstruction(Instruction.jalr(rd, imm12, rs1)),
 
         .andi => try emit.writeInstruction(Instruction.andi(rd, rs1, imm12)),
@@ -288,6 +291,11 @@ fn mirIType(emit: *Emit, inst: Mir.Inst.Index) !void {
             try emit.writeInstruction(Instruction.xori(rd, rs1, imm12));
             try emit.writeInstruction(Instruction.sltiu(rd, rd, 1));
         },
+
+        .cmp_imm_lte => {
+            try emit.writeInstruction(Instruction.sltiu(rd, rs1, @bitCast(imm12)));
+        },
+
         else => unreachable,
     }
 }
src/arch/riscv64/Mir.zig
@@ -25,6 +25,7 @@ pub const Inst = struct {
 
     pub const Tag = enum(u16) {
         addi,
+        addiw,
         jalr,
         lui,
         mv,
@@ -83,6 +84,8 @@ pub const Inst = struct {
 
         /// Immediate `==`, uses i_type
         cmp_imm_eq,
+        /// Immediate `<=`, uses i_typei
+        cmp_imm_lte,
 
         /// Branch if equal, Uses b_type
         beq,
src/Sema.zig
@@ -10499,9 +10499,10 @@ fn intCast(
             const dest_max_val_scalar = try dest_scalar_ty.maxIntScalar(mod, operand_scalar_ty);
             const dest_max_val = try sema.splat(operand_ty, dest_max_val_scalar);
             const dest_max = Air.internedToRef(dest_max_val.toIntern());
-            const diff = try block.addBinOp(.sub_wrap, dest_max, operand);
 
             if (actual_info.signedness == .signed) {
+                const diff = try block.addBinOp(.sub_wrap, dest_max, operand);
+
                 // Reinterpret the sign-bit as part of the value. This will make
                 // negative differences (`operand` > `dest_max`) appear too big.
                 const unsigned_scalar_operand_ty = try mod.intType(.unsigned, actual_bits);
@@ -10542,7 +10543,7 @@ fn intCast(
                 try sema.addSafetyCheck(block, src, ok, .cast_truncated_data);
             } else {
                 const ok = if (is_vector) ok: {
-                    const is_in_range = try block.addCmpVector(diff, dest_max, .lte);
+                    const is_in_range = try block.addCmpVector(operand, dest_max, .lte);
                     const all_in_range = try block.addInst(.{
                         .tag = if (block.float_mode == .optimized) .reduce_optimized else .reduce,
                         .data = .{ .reduce = .{
@@ -10552,7 +10553,7 @@ fn intCast(
                     });
                     break :ok all_in_range;
                 } else ok: {
-                    const is_in_range = try block.addBinOp(.cmp_lte, diff, dest_max);
+                    const is_in_range = try block.addBinOp(.cmp_lte, operand, dest_max);
                     break :ok is_in_range;
                 };
                 try sema.addSafetyCheck(block, src, ok, .cast_truncated_data);