Commit 335ff5a5f4

Robin Voetter <robin@voetter.nl>
2024-03-17 18:18:35
spirv: fix optional comparison
1 parent 8ed1342
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -3307,35 +3307,76 @@ const DeclGen = struct {
                 else
                     try self.convertToDirect(Type.bool, rhs_id);
 
-                const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
                 if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
-                    return valid_cmp_id;
+                    return try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
                 }
 
-                // TODO: Should we short circuit here? It shouldn't affect correctness, but
-                // perhaps it will generate more efficient code.
+                // a = lhs_valid
+                // b = rhs_valid
+                // c = lhs_pl == rhs_pl
+                //
+                // For op == .eq we have:
+                //   a == b && a -> c
+                // = a == b && (!a || c)
+                //
+                // For op == .neq we have
+                //   a == b && a -> c
+                // = !(a == b && a -> c)
+                // = a != b || !(a -> c
+                // = a != b || !(!a || c)
+                // = a != b || a && !c
 
                 const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
                 const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
 
-                const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
-
-                // op == .eq  => lhs_valid == rhs_valid && lhs_pl == rhs_pl
-                // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
-
-                const result_id = self.spv.allocId();
-                const args = .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = result_id,
-                    .operand_1 = valid_cmp_id,
-                    .operand_2 = pl_cmp_id,
-                };
                 switch (op) {
-                    .eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args),
-                    .neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args),
+                    .eq => {
+                        const valid_eq_id = try self.cmp(.eq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+                        const pl_eq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+                        const lhs_not_valid_id = self.spv.allocId();
+                        try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
+                            .id_result_type = self.typeId(bool_ty_ref),
+                            .id_result = lhs_not_valid_id,
+                            .operand = lhs_valid_id,
+                        });
+                        const impl_id = self.spv.allocId();
+                        try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                            .id_result_type = self.typeId(bool_ty_ref),
+                            .id_result = impl_id,
+                            .operand_1 = lhs_not_valid_id,
+                            .operand_2 = pl_eq_id,
+                        });
+                        const result_id = self.spv.allocId();
+                        try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+                            .id_result_type = self.typeId(bool_ty_ref),
+                            .id_result = result_id,
+                            .operand_1 = valid_eq_id,
+                            .operand_2 = impl_id,
+                        });
+                        return result_id;
+                    },
+                    .neq => {
+                        const valid_neq_id = try self.cmp(.neq, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
+                        const pl_neq_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
+
+                        const impl_id = self.spv.allocId();
+                        try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, .{
+                            .id_result_type = self.typeId(bool_ty_ref),
+                            .id_result = impl_id,
+                            .operand_1 = lhs_valid_id,
+                            .operand_2 = pl_neq_id,
+                        });
+                        const result_id = self.spv.allocId();
+                        try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                            .id_result_type = self.typeId(bool_ty_ref),
+                            .id_result = result_id,
+                            .operand_1 = valid_neq_id,
+                            .operand_2 = impl_id,
+                        });
+                        return result_id;
+                    },
                     else => unreachable,
                 }
-                return result_id;
             },
             .Vector => {
                 var wip = try self.elementWise(result_ty, true);