Commit 083b7b483e

David Rubin <daviru007@icloud.com>
2024-05-12 21:24:59
riscv: zero registers when using register-wide operations
what was happening is that instructions like `lb` only affect the lower byte of the register and leave the upper bits dirty. this led to situations where `cmp_eq`, for example, which is implemented with `xor`, failed because of the leftover bits in the top of the register. with this commit, we now zero out or truncate depending on the context, ensuring that instructions like `xor` produce correct results.
1 parent b679956
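
To make the failure concrete before the diff: on RV64, a byte load sign-extends into the full 64-bit register, so two `u8` values that are equal in their low 8 bits can still differ register-wide. Here is a minimal standalone sketch (plain Zig, not compiler code) of why an xor-based `cmp_eq` then misfires:

```zig
const std = @import("std");

test "stale upper bits defeat an xor-based cmp_eq" {
    // what `lb` leaves in the register after loading the byte 0x80:
    // the sign bit is copied into bits 8..63.
    const loaded: u64 = 0xFFFF_FFFF_FFFF_FF80;
    // the same u8 value materialized with clean upper bits:
    const immediate: u64 = 0x0000_0000_0000_0080;

    // a register-wide xor compares all 64 bits, so equal bytes look unequal
    try std.testing.expect(loaded ^ immediate != 0);

    // masking both down to their 8 significant bits restores the truth
    try std.testing.expect((loaded & 0xFF) ^ (immediate & 0xFF) == 0);
}
```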
src/arch/riscv64/bits.zig
@@ -102,7 +102,7 @@ pub const Memory = struct {
 
 pub const Immediate = union(enum) {
     signed: i32,
-    unsigned: u32,
+    unsigned: u64,
 
     pub fn u(x: u64) Immediate {
         return .{ .unsigned = x };
src/arch/riscv64/CodeGen.zig
@@ -1668,12 +1668,66 @@ fn allocReg(self: *Self, reg_class: abi.RegisterClass) !struct { Register, Regis
     return .{ reg, lock };
 }
 
+const PromoteOptions = struct {
+    /// zeroes out the register before loading the operand into it
+    ///
+    /// if the operand is already a register, its upper bits are zeroed in place
+    zero: bool = false,
+};
+
 /// Similar to `allocReg` but will copy the MCValue into the Register unless `operand` is already
 /// a register, in which case it will return a possible lock to that register.
-fn promoteReg(self: *Self, ty: Type, operand: MCValue) !struct { Register, ?RegisterLock } {
-    if (operand == .register) return .{ operand.register, self.register_manager.lockReg(operand.register) };
+fn promoteReg(self: *Self, ty: Type, operand: MCValue, options: PromoteOptions) !struct { Register, ?RegisterLock } {
+    const zcu = self.bin_file.comp.module.?;
+    const bit_size = ty.bitSize(zcu);
+
+    if (operand == .register) {
+        const op_reg = operand.register;
+        if (options.zero and op_reg.class() == .int) {
+            // we emit the truncate manually here: binOp calls this function,
+            // so routing through the normal truncate path could loop forever
+
+            _ = try self.addInst(.{
+                .tag = .slli,
+                .ops = .rri,
+                .data = .{
+                    .i_type = .{
+                        .imm12 = Immediate.u(64 - bit_size),
+                        .rd = op_reg,
+                        .rs1 = op_reg,
+                    },
+                },
+            });
+
+            _ = try self.addInst(.{
+                .tag = .srli,
+                .ops = .rri,
+                .data = .{
+                    .i_type = .{
+                        .imm12 = Immediate.u(64 - bit_size),
+                        .rd = op_reg,
+                        .rs1 = op_reg,
+                    },
+                },
+            });
+        }
+
+        return .{ op_reg, self.register_manager.lockReg(operand.register) };
+    }
 
     const reg, const lock = try self.allocReg(self.typeRegClass(ty));
+
+    if (options.zero and reg.class() == .int) {
+        _ = try self.addInst(.{
+            .tag = .pseudo,
+            .ops = .pseudo_mv,
+            .data = .{ .rr = .{
+                .rd = reg,
+                .rs = .zero,
+            } },
+        });
+    }
+
     try self.genSetReg(ty, reg, operand);
     return .{ reg, lock };
 }
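
For a concrete picture of the two paths above, assuming a `u8` operand (so `64 - bit_size` is 56): the in-place case emits `slli rd, rs1, 56` followed by `srli rd, rs1, 56`, pushing the 8 significant bits to the top so the stale bits fall off and then bringing them back down zero-filled, while the fresh-allocation case clears the register with a move from the zero register before `genSetReg` runs. A sketch of the shift pair's effect, as a hypothetical standalone helper (not compiler code):

```zig
const std = @import("std");

/// Hypothetical helper mirroring the emitted slli/srli pair:
/// zero-extends the low `bit_size` bits of `reg` in place.
fn shiftPairZeroExtend(reg: u64, bit_size: u7) u64 {
    std.debug.assert(bit_size >= 1 and bit_size <= 64);
    const shift: u6 = @intCast(64 - @as(u8, bit_size));
    return (reg << shift) >> shift; // slli, then srli
}

test "shift pair zero-extends an 8-bit value" {
    try std.testing.expectEqual(
        @as(u64, 0x80),
        shiftPairZeroExtend(0xFFFF_FFFF_FFFF_FF80, 8),
    );
}
```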
@@ -2124,10 +2178,10 @@ fn binOpRegister(
     rhs: MCValue,
     rhs_ty: Type,
 ) !MCValue {
-    const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+    const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{ .zero = true });
     defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-    const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+    const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{ .zero = true });
     defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
     const dest_reg, const dest_lock = try self.allocReg(.int);
@@ -2223,10 +2277,10 @@ fn binOpFloat(
     const zcu = self.bin_file.comp.module.?;
     const float_bits = lhs_ty.floatBits(zcu.getTarget());
 
-    const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+    const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{});
     defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-    const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+    const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{});
     defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
     const mir_tag: Mir.Inst.Tag = switch (tag) {
@@ -2425,10 +2479,10 @@ fn airSubWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
         const result_mcv = try self.allocRegOrMem(inst, false);
         const offset = result_mcv.load_frame;
 
-        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{});
         defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{});
         defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
         const dest_reg, const dest_lock = try self.allocReg(.int);
@@ -2559,7 +2613,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) !void {
                                         1...8 => {
                                             const max_val = std.math.pow(u16, 2, int_info.bits) - 1;
 
-                                            const add_reg, const add_lock = try self.promoteReg(lhs_ty, lhs);
+                                            const add_reg, const add_lock = try self.promoteReg(lhs_ty, lhs, .{});
                                             defer if (add_lock) |lock| self.register_manager.unlockReg(lock);
 
                                             const overflow_reg, const overflow_lock = try self.allocReg(.int);
@@ -2645,10 +2699,10 @@ fn airBitAnd(self: *Self, inst: Air.Inst.Index) !void {
         const lhs_ty = self.typeOf(bin_op.lhs);
         const rhs_ty = self.typeOf(bin_op.rhs);
 
-        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{});
         defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{});
         defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
         const dest_reg, const dest_lock = try self.allocReg(.int);
@@ -2678,10 +2732,10 @@ fn airBitOr(self: *Self, inst: Air.Inst.Index) !void {
         const lhs_ty = self.typeOf(bin_op.lhs);
         const rhs_ty = self.typeOf(bin_op.rhs);
 
-        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{});
         defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{});
         defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
         const dest_reg, const dest_lock = try self.allocReg(.int);
@@ -4706,10 +4760,10 @@ fn airBoolOp(self: *Self, inst: Air.Inst.Index) !void {
         const lhs_ty = Type.bool;
         const rhs_ty = Type.bool;
 
-        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs);
+        const lhs_reg, const lhs_lock = try self.promoteReg(lhs_ty, lhs, .{});
         defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
 
-        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs);
+        const rhs_reg, const rhs_lock = try self.promoteReg(rhs_ty, rhs, .{});
         defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
 
         const result_reg, const result_lock = try self.allocReg(.int);
@@ -4905,7 +4959,7 @@ fn genCopy(self: *Self, ty: Type, dst_mcv: MCValue, src_mcv: MCValue) !void {
             const src_info: ?struct { addr_reg: Register, addr_lock: ?RegisterLock } = switch (src_mcv) {
                 .register_pair, .memory, .indirect, .load_frame => null,
                 .load_symbol => src: {
-                    const src_addr_reg, const src_addr_lock = try self.promoteReg(Type.usize, src_mcv.address());
+                    const src_addr_reg, const src_addr_lock = try self.promoteReg(Type.usize, src_mcv.address(), .{});
                     errdefer self.register_manager.unlockReg(src_addr_lock);
 
                     break :src .{ .addr_reg = src_addr_reg, .addr_lock = src_addr_lock };
@@ -5463,7 +5517,7 @@ fn genSetMem(
         .immediate => {
             // TODO: remove this lock in favor of a copyToTmpRegister when we load 64 bit immediates with
             // a register allocation.
-            const reg, const reg_lock = try self.promoteReg(ty, src_mcv);
+            const reg, const reg_lock = try self.promoteReg(ty, src_mcv, .{});
             defer if (reg_lock) |lock| self.register_manager.unlockReg(lock);
 
             return self.genSetMem(base, disp, ty, .{ .register = reg });
test/behavior/array.zig
@@ -542,7 +542,6 @@ test "sentinel element count towards the ABI size calculation" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
test/behavior/bitcast.zig
@@ -165,7 +165,6 @@ test "@bitCast packed structs at runtime and comptime" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const Full = packed struct {
         number: u16,
@@ -192,7 +191,6 @@ test "@bitCast packed structs at runtime and comptime" {
 test "@bitCast extern structs at runtime and comptime" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const Full = extern struct {
         number: u16,
@@ -227,7 +225,6 @@ test "bitcast packed struct to integer and back" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const LevelUpMove = packed struct {
         move_id: u9,
test/behavior/cast.zig
@@ -57,8 +57,6 @@ test "@intCast to comptime_int" {
 }
 
 test "implicit cast comptime numbers to any type when the value fits" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     const a: u64 = 255;
     var b: u8 = a;
     _ = &b;
test/behavior/error.zig
@@ -740,7 +740,6 @@ test "ret_ptr doesn't cause own inferred error set to be resolved" {
 
 test "simple else prong allowed even when all errors handled" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         fn foo() !u8 {
test/behavior/eval.zig
@@ -395,7 +395,6 @@ test "return 0 from function that has u0 return type" {
 test "statically initialized struct" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     st_init_str_foo.x += 1;
     try expect(st_init_str_foo.x == 14);
@@ -787,7 +786,6 @@ test "array concatenation peer resolves element types - pointer" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var a = [2]u3{ 1, 7 };
     var b = [3]u8{ 200, 225, 255 };
test/behavior/math.zig
@@ -605,8 +605,6 @@ fn testSignedNegationWrappingEval(x: i16) !void {
 }
 
 test "unsigned negation wrapping" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     try testUnsignedNegationWrappingEval(1);
     try comptime testUnsignedNegationWrappingEval(1);
 }
@@ -1436,8 +1434,6 @@ test "quad hex float literal parsing accurate" {
 }
 
 test "truncating shift left" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     try testShlTrunc(maxInt(u16));
     try comptime testShlTrunc(maxInt(u16));
 }
test/behavior/packed-struct.zig
@@ -258,7 +258,6 @@ test "nested packed struct unaligned" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (native_endian != .little) return error.SkipZigTest; // Byte aligned packed struct field pointers have not been implemented yet
 
     const S1 = packed struct {
@@ -331,7 +330,6 @@ test "byte-aligned field pointer offsets" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         const A = packed struct {
test/behavior/reflection.zig
@@ -28,7 +28,6 @@ fn dummy(a: bool, b: i32, c: f32) i32 {
 test "reflection: @field" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var f = Foo{
         .one = 42,
test/behavior/struct.zig
@@ -68,7 +68,6 @@ const SmallStruct = struct {
 
 test "lower unnamed constants" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var foo = SmallStruct{ .a = 1, .b = 255 };
     try expect(foo.first() == 1);
@@ -395,7 +394,6 @@ test "packed struct" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var foo = APackedStruct{
         .x = 1,
@@ -876,7 +874,6 @@ test "packed struct field passed to generic function" {
 test "anonymous struct literal syntax" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         const Point = struct {
@@ -1106,7 +1103,6 @@ test "packed struct with undefined initializers" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         const P = packed struct {
@@ -1369,7 +1365,6 @@ test "store to comptime field" {
 test "struct field init value is size of the struct" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const namespace = struct {
         const S = extern struct {
test/behavior/this.zig
@@ -27,7 +27,6 @@ test "this refer to module call private fn" {
 test "this refer to container" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var pt: Point(i32) = undefined;
     pt.x = 12;
test/behavior/union.zig
@@ -2025,7 +2025,6 @@ test "inner struct initializer uses packed union layout" {
 
 test "extern union initialized via reintepreted struct field initializer" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const bytes = [_]u8{ 0xaa, 0xbb, 0xcc, 0xdd };
 
@@ -2045,7 +2044,6 @@ test "extern union initialized via reintepreted struct field initializer" {
 
 test "packed union initialized via reintepreted struct field initializer" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const bytes = [_]u8{ 0xaa, 0xbb, 0xcc, 0xdd };
 
@@ -2066,7 +2064,6 @@ test "packed union initialized via reintepreted struct field initializer" {
 
 test "store of comptime reinterpreted memory to extern union" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const bytes = [_]u8{ 0xaa, 0xbb, 0xcc, 0xdd };
 
@@ -2089,7 +2086,6 @@ test "store of comptime reinterpreted memory to extern union" {
 
 test "store of comptime reinterpreted memory to packed union" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const bytes = [_]u8{ 0xaa, 0xbb, 0xcc, 0xdd };