Commit f4064d98e2

Robin Voetter <robin@voetter.nl>
2023-10-12 22:09:34
spirv: optional comparison
1 parent 10b8171
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -52,12 +52,6 @@ const Block = struct {
 
 const BlockMap = std.AutoHashMapUnmanaged(Air.Inst.Index, *Block);
 
-/// Maps Zig decl indices to SPIR-V linking information.
-pub const DeclLinkMap = std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index);
-
-/// Maps anon decl indices to SPIR-V linking information.
-pub const AnonDeclLinkMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index);
-
 /// This structure holds information that is relevant to the entire compilation,
 /// in contrast to `DeclGen`, which only holds relevant information about a
 /// single decl.
@@ -70,10 +64,10 @@ pub const Object = struct {
 
     /// The Zig module that this object file is generated for.
     /// A map of Zig decl indices to SPIR-V decl indices.
-    decl_link: DeclLinkMap = .{},
+    decl_link: std.AutoHashMapUnmanaged(Decl.Index, SpvModule.Decl.Index) = .{},
 
     /// A map of Zig InternPool indices for anonymous decls to SPIR-V decl indices.
-    anon_decl_link: AnonDeclLinkMap = .{},
+    anon_decl_link: std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, SpvModule.Decl.Index) = .{},
 
     /// A map that maps AIR intern pool indices to SPIR-V cache references (which
     /// is basically the same thing except for SPIR-V).
@@ -1266,22 +1260,32 @@ const DeclGen = struct {
 
                 const elem_ty = ty.childType(mod);
                 const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
-                const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse {
+                var total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse {
                     return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)});
                 };
-                if (!ty.hasRuntimeBitsIgnoreComptime(mod)) {
+                const ty_ref = if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) blk: {
                     // The size of the array would be 0, but that is not allowed in SPIR-V.
-                    // This path can be reached for example when there is a slicing of a pointer
-                    // that produces a zero-length array. In all cases where this type can be generated,
-                    // we should be in an indirect path (direct uses of this type should be filtered out in Sema).
+                    // This path can be reached when the backend is asked to generate a pointer to
+                    // an array of some zero-bit type. This should always be an indirect path.
                     assert(repr == .indirect);
 
-                    return try self.spv.resolve(.{ .opaque_type = .{
+                    // We cannot use the child type here, so just use an opaque type.
+                    break :blk try self.spv.resolve(.{ .opaque_type = .{
                         .name = try self.spv.resolveString("zero-sized array"),
                     } });
-                }
+                } else if (total_len == 0) blk: {
+                    // The size of the array would be 0, but that is not allowed in SPIR-V.
+                    // This path can be reached for example when there is a slicing of a pointer
+                    // that produces a zero-length array. In all cases where this type can be generated,
+                    // this should be an indirect path.
+                    assert(repr == .indirect);
+
+                    // Here we have a zero-length array of a non-zero sized type. Generate an
+                    // array of 1 element instead, so that ptr_elem_ptr instructions can be
+                    // lowered to ptrAccessChain instead of manually performing the math.
+                    break :blk try self.spv.arrayType(1, elem_ty_ref);
+                } else try self.spv.arrayType(total_len, elem_ty_ref);
 
-                const ty_ref = try self.spv.arrayType(total_len, elem_ty_ref);
                 try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
                 return ty_ref;
             },
@@ -2554,38 +2558,85 @@ const DeclGen = struct {
         var cmp_lhs_id = lhs_id;
         var cmp_rhs_id = rhs_id;
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
-        const opcode: Opcode = opcode: {
-            const op_ty = switch (ty.zigTypeTag(mod)) {
-                .Int, .Bool, .Float => ty,
-                .Enum => ty.intTagType(mod),
-                .ErrorSet => Type.u16,
-                .Pointer => blk: {
-                    // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are
-                    // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using
-                    // OpConvertPtrToU...
-                    cmp_lhs_id = self.spv.allocId();
-                    cmp_rhs_id = self.spv.allocId();
-
-                    const usize_ty_id = self.typeId(try self.sizeType());
-
-                    try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
-                        .id_result_type = usize_ty_id,
-                        .id_result = cmp_lhs_id,
-                        .pointer = lhs_id,
-                    });
+        const op_ty = switch (ty.zigTypeTag(mod)) {
+            .Int, .Bool, .Float => ty,
+            .Enum => ty.intTagType(mod),
+            .ErrorSet => Type.u16,
+            .Pointer => blk: {
+                // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are
+                // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using
+                // OpConvertPtrToU...
+                cmp_lhs_id = self.spv.allocId();
+                cmp_rhs_id = self.spv.allocId();
+
+                const usize_ty_id = self.typeId(try self.sizeType());
+
+                try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
+                    .id_result_type = usize_ty_id,
+                    .id_result = cmp_lhs_id,
+                    .pointer = lhs_id,
+                });
 
-                    try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
-                        .id_result_type = usize_ty_id,
-                        .id_result = cmp_rhs_id,
-                        .pointer = rhs_id,
-                    });
+                try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
+                    .id_result_type = usize_ty_id,
+                    .id_result = cmp_rhs_id,
+                    .pointer = rhs_id,
+                });
 
-                    break :blk Type.usize;
-                },
-                .Optional => unreachable, // TODO
-                else => unreachable,
-            };
+                break :blk Type.usize;
+            },
+            .Optional => {
+                const payload_ty = ty.optionalChild(mod);
+                if (ty.optionalReprIsPayload(mod)) {
+                    assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod));
+                    assert(!payload_ty.isSlice(mod));
+                    return self.cmp(op, payload_ty, lhs_id, rhs_id);
+                }
+
+                const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
+                    try self.extractField(Type.bool, lhs_id, 1)
+                else
+                    try self.convertToDirect(Type.bool, lhs_id);
 
+                const rhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
+                    try self.extractField(Type.bool, rhs_id, 1)
+                else
+                    try self.convertToDirect(Type.bool, rhs_id);
+
+                const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id);
+                if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+                    return valid_cmp_id;
+                }
+
+                // TODO: Should we short circuit here? It shouldn't affect correctness, but
+                // perhaps it will generate more efficient code.
+
+                const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
+                const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
+
+                const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id);
+
+                // op == .eq  => lhs_valid == rhs_valid && lhs_pl == rhs_pl
+                // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
+
+                const result_id = self.spv.allocId();
+                const args = .{
+                    .id_result_type = self.typeId(bool_ty_ref),
+                    .id_result = result_id,
+                    .operand_1 = valid_cmp_id,
+                    .operand_2 = pl_cmp_id,
+                };
+                switch (op) {
+                    .eq => try self.func.body.emit(self.spv.gpa, .OpLogicalAnd, args),
+                    .neq => try self.func.body.emit(self.spv.gpa, .OpLogicalOr, args),
+                    else => unreachable,
+                }
+                return result_id;
+            },
+            else => unreachable,
+        };
+
+        const opcode: Opcode = opcode: {
             const info = try self.arithmeticTypeInfo(op_ty);
             const signedness = switch (info.class) {
                 .composite_integer => {
@@ -2653,7 +2704,6 @@ const DeclGen = struct {
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
         const ty = self.typeOf(bin_op.lhs);
-        assert(ty.eql(self.typeOf(bin_op.rhs), self.module));
 
         return try self.cmp(op, ty, lhs_id, rhs_id);
     }
@@ -3061,16 +3111,17 @@ const DeclGen = struct {
         const mod = self.module;
         const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
         const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data;
-        const ptr_ty = self.typeOf(bin_op.lhs);
-        const elem_ty = ptr_ty.childType(mod);
+        const src_ptr_ty = self.typeOf(bin_op.lhs);
+        const elem_ty = src_ptr_ty.childType(mod);
+        const ptr_id = try self.resolve(bin_op.lhs);
+
         if (!elem_ty.hasRuntimeBitsIgnoreComptime(mod)) {
-            const ptr_ty_ref = try self.resolveType(ptr_ty, .direct);
-            return try self.spv.constUndef(ptr_ty_ref);
+            const dst_ptr_ty = self.typeOfIndex(inst);
+            return try self.bitCast(dst_ptr_ty, src_ptr_ty, ptr_id);
         }
 
-        const ptr_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
-        return try self.ptrElemPtr(ptr_ty, ptr_id, index_id);
+        return try self.ptrElemPtr(src_ptr_ty, ptr_id, index_id);
     }
 
     fn airArrayElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {