Commit e5c439a16d

Jacob Young <jacobly0@users.noreply.github.com>
2024-02-17 01:27:19
x86_64: implement optional comparisons
Closes #18959
1 parent a76d8ca
Changed files (3)
src
arch
codegen
test
behavior
src/arch/x86_64/CodeGen.zig
@@ -12396,9 +12396,36 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) !void {
 fn airCmp(self: *Self, inst: Air.Inst.Index, op: math.CompareOperator) !void {
     const mod = self.bin_file.comp.module.?;
     const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
-    const ty = self.typeOf(bin_op.lhs);
+    var ty = self.typeOf(bin_op.lhs);
+    var null_compare: ?Mir.Inst.Index = null;
 
     const result: Condition = result: {
+        try self.spillEflagsIfOccupied();
+
+        const lhs_mcv = try self.resolveInst(bin_op.lhs);
+        const lhs_locks: [2]?RegisterLock = switch (lhs_mcv) {
+            .register => |lhs_reg| .{ self.register_manager.lockRegAssumeUnused(lhs_reg), null },
+            .register_pair => |lhs_regs| locks: {
+                const locks = self.register_manager.lockRegsAssumeUnused(2, lhs_regs);
+                break :locks .{ locks[0], locks[1] };
+            },
+            .register_offset => |lhs_ro| .{
+                self.register_manager.lockRegAssumeUnused(lhs_ro.reg),
+                null,
+            },
+            else => .{null} ** 2,
+        };
+        defer for (lhs_locks) |lhs_lock| if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+        const rhs_mcv = try self.resolveInst(bin_op.rhs);
+        const rhs_locks: [2]?RegisterLock = switch (rhs_mcv) {
+            .register => |rhs_reg| .{ self.register_manager.lockReg(rhs_reg), null },
+            .register_pair => |rhs_regs| self.register_manager.lockRegs(2, rhs_regs),
+            .register_offset => |rhs_ro| .{ self.register_manager.lockReg(rhs_ro.reg), null },
+            else => .{null} ** 2,
+        };
+        defer for (rhs_locks) |rhs_lock| if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
+
         switch (ty.zigTypeTag(mod)) {
             .Float => {
                 const float_bits = ty.floatBits(self.target.*);
@@ -12435,34 +12462,66 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: math.CompareOperator) !void {
                     };
                 }
             },
-            else => {},
-        }
+            .Optional => if (!ty.optionalReprIsPayload(mod)) {
+                const opt_ty = ty;
+                const opt_abi_size: u31 = @intCast(opt_ty.abiSize(mod));
+                ty = opt_ty.optionalChild(mod);
+                const payload_abi_size: u31 = @intCast(ty.abiSize(mod));
 
-        try self.spillEflagsIfOccupied();
+                const temp_lhs_reg = try self.register_manager.allocReg(null, abi.RegisterClass.gp);
+                const temp_lhs_lock = self.register_manager.lockRegAssumeUnused(temp_lhs_reg);
+                defer self.register_manager.unlockReg(temp_lhs_lock);
 
-        const lhs_mcv = try self.resolveInst(bin_op.lhs);
-        const lhs_locks: [2]?RegisterLock = switch (lhs_mcv) {
-            .register => |lhs_reg| .{ self.register_manager.lockRegAssumeUnused(lhs_reg), null },
-            .register_pair => |lhs_regs| locks: {
-                const locks = self.register_manager.lockRegsAssumeUnused(2, lhs_regs);
-                break :locks .{ locks[0], locks[1] };
-            },
-            .register_offset => |lhs_ro| .{
-                self.register_manager.lockRegAssumeUnused(lhs_ro.reg),
-                null,
-            },
-            else => .{null} ** 2,
-        };
-        defer for (lhs_locks) |lhs_lock| if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
+                if (lhs_mcv.isMemory()) try self.asmRegisterMemory(
+                    .{ ._, .mov },
+                    temp_lhs_reg.to8(),
+                    try lhs_mcv.address().offset(payload_abi_size).deref().mem(self, .byte),
+                ) else {
+                    try self.genSetReg(temp_lhs_reg, opt_ty, lhs_mcv, .{});
+                    try self.asmRegisterImmediate(
+                        .{ ._r, .sh },
+                        registerAlias(temp_lhs_reg, opt_abi_size),
+                        Immediate.u(payload_abi_size * 8),
+                    );
+                }
 
-        const rhs_mcv = try self.resolveInst(bin_op.rhs);
-        const rhs_locks: [2]?RegisterLock = switch (rhs_mcv) {
-            .register => |rhs_reg| .{ self.register_manager.lockReg(rhs_reg), null },
-            .register_pair => |rhs_regs| self.register_manager.lockRegs(2, rhs_regs),
-            .register_offset => |rhs_ro| .{ self.register_manager.lockReg(rhs_ro.reg), null },
-            else => .{null} ** 2,
-        };
-        defer for (rhs_locks) |rhs_lock| if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
+                const payload_compare = payload_compare: {
+                    if (rhs_mcv.isMemory()) {
+                        const rhs_mem =
+                            try rhs_mcv.address().offset(payload_abi_size).deref().mem(self, .byte);
+                        try self.asmMemoryRegister(.{ ._, .@"test" }, rhs_mem, temp_lhs_reg.to8());
+                        const payload_compare = try self.asmJccReloc(.nz, undefined);
+                        try self.asmRegisterMemory(.{ ._, .cmp }, temp_lhs_reg.to8(), rhs_mem);
+                        break :payload_compare payload_compare;
+                    }
+
+                    const temp_rhs_reg = try self.copyToTmpRegister(opt_ty, rhs_mcv);
+                    const temp_rhs_lock = self.register_manager.lockRegAssumeUnused(temp_rhs_reg);
+                    defer self.register_manager.unlockReg(temp_rhs_lock);
+
+                    try self.asmRegisterImmediate(
+                        .{ ._r, .sh },
+                        registerAlias(temp_rhs_reg, opt_abi_size),
+                        Immediate.u(payload_abi_size * 8),
+                    );
+                    try self.asmRegisterRegister(
+                        .{ ._, .@"test" },
+                        temp_lhs_reg.to8(),
+                        temp_rhs_reg.to8(),
+                    );
+                    const payload_compare = try self.asmJccReloc(.nz, undefined);
+                    try self.asmRegisterRegister(
+                        .{ ._, .cmp },
+                        temp_lhs_reg.to8(),
+                        temp_rhs_reg.to8(),
+                    );
+                    break :payload_compare payload_compare;
+                };
+                null_compare = try self.asmJmpReloc(undefined);
+                self.performReloc(payload_compare);
+            },
+            else => {},
+        }
 
         switch (ty.zigTypeTag(mod)) {
             else => {
@@ -12775,6 +12834,7 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: math.CompareOperator) !void {
         }
     };
 
+    if (null_compare) |reloc| self.performReloc(reloc);
     self.eflags_inst = inst;
     return self.finishAir(inst, .{ .eflags = result }, .{ bin_op.lhs, bin_op.rhs, .none });
 }
src/codegen/c.zig
@@ -4140,9 +4140,7 @@ fn airCmpOp(
     if (need_cast) try writer.writeAll("(void*)");
     try f.writeCValue(writer, lhs, .Other);
     try v.elem(f, writer);
-    try writer.writeByte(' ');
     try writer.writeAll(compareOperatorC(operator));
-    try writer.writeByte(' ');
     if (need_cast) try writer.writeAll("(void*)");
     try f.writeCValue(writer, rhs, .Other);
     try v.elem(f, writer);
@@ -4181,41 +4179,28 @@ fn airEquality(
     const writer = f.object.writer();
     const inst_ty = f.typeOfIndex(inst);
     const local = try f.allocLocal(inst, inst_ty);
+    const a = try Assignment.start(f, writer, inst_ty);
     try f.writeCValue(writer, local, .Other);
-    try writer.writeAll(" = ");
+    try a.assign(f, writer);
 
     if (operand_ty.zigTypeTag(mod) == .Optional and !operand_ty.optionalReprIsPayload(mod)) {
-        // (A && B)  || (C && (A == B))
-        // A = lhs.is_null  ;  B = rhs.is_null  ;  C = rhs.payload == lhs.payload
-
-        switch (operator) {
-            .eq => {},
-            .neq => try writer.writeByte('!'),
-            else => unreachable,
-        }
-        try writer.writeAll("((");
-        try f.writeCValue(writer, lhs, .Other);
-        try writer.writeAll(".is_null && ");
-        try f.writeCValue(writer, rhs, .Other);
-        try writer.writeAll(".is_null) || (");
-        try f.writeCValue(writer, lhs, .Other);
-        try writer.writeAll(".payload == ");
-        try f.writeCValue(writer, rhs, .Other);
-        try writer.writeAll(".payload && ");
+        try f.writeCValueMember(writer, lhs, .{ .identifier = "is_null" });
+        try writer.writeAll(" || ");
+        try f.writeCValueMember(writer, rhs, .{ .identifier = "is_null" });
+        try writer.writeAll(" ? ");
+        try f.writeCValueMember(writer, lhs, .{ .identifier = "is_null" });
+        try writer.writeAll(compareOperatorC(operator));
+        try f.writeCValueMember(writer, rhs, .{ .identifier = "is_null" });
+        try writer.writeAll(" : ");
+        try f.writeCValueMember(writer, lhs, .{ .identifier = "payload" });
+        try writer.writeAll(compareOperatorC(operator));
+        try f.writeCValueMember(writer, rhs, .{ .identifier = "payload" });
+    } else {
         try f.writeCValue(writer, lhs, .Other);
-        try writer.writeAll(".is_null == ");
+        try writer.writeAll(compareOperatorC(operator));
         try f.writeCValue(writer, rhs, .Other);
-        try writer.writeAll(".is_null));\n");
-
-        return local;
     }
-
-    try f.writeCValue(writer, lhs, .Other);
-    try writer.writeByte(' ');
-    try writer.writeAll(compareOperatorC(operator));
-    try writer.writeByte(' ');
-    try f.writeCValue(writer, rhs, .Other);
-    try writer.writeAll(";\n");
+    try a.end(f, writer);
 
     return local;
 }
@@ -6322,7 +6307,7 @@ fn airCmpBuiltinCall(
     try v.elem(f, writer);
     try f.object.dg.renderBuiltinInfo(writer, scalar_ty, info);
     try writer.writeByte(')');
-    if (!ref_ret) try writer.print(" {s} {}", .{
+    if (!ref_ret) try writer.print("{s}{}", .{
         compareOperatorC(operator),
         try f.fmtIntLiteral(Type.i32, try mod.intValue(Type.i32, 0)),
     });
@@ -7668,12 +7653,12 @@ fn compareOperatorAbbrev(operator: std.math.CompareOperator) []const u8 {
 
 fn compareOperatorC(operator: std.math.CompareOperator) []const u8 {
     return switch (operator) {
-        .lt => "<",
-        .lte => "<=",
-        .eq => "==",
-        .gte => ">=",
-        .gt => ">",
-        .neq => "!=",
+        .lt => " < ",
+        .lte => " <= ",
+        .eq => " == ",
+        .gte => " >= ",
+        .gt => " > ",
+        .neq => " != ",
     };
 }
 
test/behavior/optional.zig
@@ -110,44 +110,89 @@ test "nested optional field in struct" {
     try expect(s.x.?.y == 127);
 }
 
-test "equality compare optional with non-optional" {
+test "equality compare optionals and non-optionals" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
-    try test_cmp_optional_non_optional();
-    try comptime test_cmp_optional_non_optional();
+    const S = struct {
+        fn doTheTest() !void {
+            var five: isize = 5;
+            var ten: isize = 10;
+            var opt_null: ?isize = null;
+            var opt_ten: ?isize = 10;
+            _ = .{ &five, &ten, &opt_null, &opt_ten };
+            try expect(opt_null != five);
+            try expect(opt_null != ten);
+            try expect(opt_ten != five);
+            try expect(opt_ten == ten);
+
+            var opt_int: ?isize = null;
+            try expect(opt_int != five);
+            try expect(opt_int != ten);
+            try expect(opt_int == opt_null);
+            try expect(opt_int != opt_ten);
+
+            opt_int = 10;
+            try expect(opt_int != five);
+            try expect(opt_int == ten);
+            try expect(opt_int != opt_null);
+            try expect(opt_int == opt_ten);
+
+            opt_int = five;
+            try expect(opt_int == five);
+            try expect(opt_int != ten);
+            try expect(opt_int != opt_null);
+            try expect(opt_int != opt_ten);
+
+            // test evaluation is always lexical
+            // ensure that the optional isn't always computed before the non-optional
+            var mutable_state: i32 = 0;
+            _ = blk1: {
+                mutable_state += 1;
+                break :blk1 @as(?f64, 10.0);
+            } != blk2: {
+                try expect(mutable_state == 1);
+                break :blk2 @as(f64, 5.0);
+            };
+            _ = blk1: {
+                mutable_state += 1;
+                break :blk1 @as(f64, 10.0);
+            } != blk2: {
+                try expect(mutable_state == 2);
+                break :blk2 @as(?f64, 5.0);
+            };
+        }
+    };
+
+    try S.doTheTest();
+    try comptime S.doTheTest();
 }
 
-fn test_cmp_optional_non_optional() !void {
-    var ten: i32 = 10;
-    var opt_ten: ?i32 = 10;
-    var five: i32 = 5;
-    var int_n: ?i32 = null;
-
-    _ = .{ &ten, &opt_ten, &five, &int_n };
-
-    try expect(int_n != ten);
-    try expect(opt_ten == ten);
-    try expect(opt_ten != five);
-
-    // test evaluation is always lexical
-    // ensure that the optional isn't always computed before the non-optional
-    var mutable_state: i32 = 0;
-    _ = blk1: {
-        mutable_state += 1;
-        break :blk1 @as(?f64, 10.0);
-    } != blk2: {
-        try expect(mutable_state == 1);
-        break :blk2 @as(f64, 5.0);
-    };
-    _ = blk1: {
-        mutable_state += 1;
-        break :blk1 @as(f64, 10.0);
-    } != blk2: {
-        try expect(mutable_state == 2);
-        break :blk2 @as(?f64, 5.0);
-    };
+test "compare optionals with modified payloads" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+
+    var lhs: ?bool = false;
+    const lhs_payload = &lhs.?;
+    var rhs: ?bool = true;
+    const rhs_payload = &rhs.?;
+    try expect(lhs != rhs and !(lhs == rhs));
+
+    lhs = null;
+    lhs_payload.* = false;
+    rhs = false;
+    try expect(lhs != rhs and !(lhs == rhs));
+
+    lhs = true;
+    rhs = null;
+    rhs_payload.* = true;
+    try expect(lhs != rhs and !(lhs == rhs));
+
+    lhs = null;
+    lhs_payload.* = false;
+    rhs = null;
+    rhs_payload.* = true;
+    try expect(lhs == rhs and !(lhs != rhs));
 }
 
 test "unwrap function call with optional pointer return value" {