Commit 200bca360e

Robin Voetter <robin@voetter.nl>
2023-10-21 12:53:13
spirv: replace most use of spv.ptrType with self.ptrType
To support self-referential pointers, in the future we will need to pass the Zig type to any pointer that is created. This lays some ground work for that by replacing most uses of spv.ptrType with a new ptrType function that also accepts the Zig type. This function's contents will soon be replaced by a version that also supports self-referential pointers. Also fixed some bugs regarding the use of direct/indirect.
1 parent b403ca0
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -358,8 +358,7 @@ const DeclGen = struct {
 
         const mod = self.module;
         const ty = mod.intern_pool.typeOf(val).toType();
-        const ty_ref = try self.resolveType(ty, .indirect);
-        const ptr_ty_ref = try self.spv.ptrType(ty_ref, storage_class);
+        const ptr_ty_ref = try self.ptrType(ty, storage_class);
 
         const var_id = self.spv.declPtr(spv_decl_index).result_id;
 
@@ -623,25 +622,15 @@ const DeclGen = struct {
     /// result_ty_ref must be an array type.
     /// Constituents should be in `indirect` representation (as the elements of an array should be).
     /// Result is in `direct` representation.
-    fn constructArray(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef {
+    fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
         // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
         // operands are not constant.
         // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
         // For now, just initialize the struct by setting the fields manually...
         // TODO: Make this OpCompositeConstruct when we can
-        // TODO: Make this Function storage type
-        const ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function);
-        const ptr_composite_id = self.spv.allocId();
-        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
-            .id_result_type = self.typeId(ptr_ty_ref),
-            .id_result = ptr_composite_id,
-            .storage_class = .Function,
-        });
-
-        const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).array_type;
-        const elem_ty_ref = spv_composite_ty.element_type;
-        const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function);
-
+        const mod = self.module;
+        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+        const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
         for (constituents, 0..) |constitent_id, index| {
             const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
@@ -649,13 +638,8 @@ const DeclGen = struct {
                 .object = constitent_id,
             });
         }
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpLoad, .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .pointer = ptr_composite_id,
-        });
-        return result_id;
+
+        return try self.load(ty, ptr_composite_id, .{});
     }
 
     /// This function generates a load for a constant in direct (ie, non-memory) representation.
@@ -857,7 +841,7 @@ const DeclGen = struct {
                         else => {},
                     }
 
-                    return try self.constructArray(result_ty_ref, constituents);
+                    return try self.constructArray(ty, constituents);
                 },
                 .struct_type => {
                     const struct_type = mod.typeToStruct(ty).?;
@@ -892,7 +876,7 @@ const DeclGen = struct {
                 const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
                 const layout = self.unionLayout(ty, active_field);
                 const payload = if (layout.active_field_size != 0)
-                    try self.constant(layout.active_field_ty, un.val.toValue(), .indirect)
+                    try self.constant(layout.active_field_ty, un.val.toValue(), .direct)
                 else
                     null;
 
@@ -934,8 +918,7 @@ const DeclGen = struct {
 
                 // TODO: Can we consolidate this in ptrElemPtr?
                 const elem_ty = parent_ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T.
-                const elem_ty_ref = try self.resolveType(elem_ty, .direct);
-                const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod)));
+                const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(parent_ptr_ty.ptrAddressSpace(mod)));
 
                 if (elem_ptr_ty_ref == result_ty_ref) {
                     return elem_ptr_id;
@@ -992,8 +975,7 @@ const DeclGen = struct {
         };
 
         const decl_id = try self.resolveAnonDecl(decl_val, actual_storage_class);
-        const decl_ty_ref = try self.resolveType(decl_ty, .indirect);
-        const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class);
+        const decl_ptr_ty_ref = try self.ptrType(decl_ty, final_storage_class);
 
         const ptr_id = switch (final_storage_class) {
             .Generic => blk: {
@@ -1049,8 +1031,7 @@ const DeclGen = struct {
 
         const final_storage_class = spvStorageClass(decl.@"addrspace");
 
-        const decl_ty_ref = try self.resolveType(decl.ty, .indirect);
-        const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class);
+        const decl_ptr_ty_ref = try self.ptrType(decl.ty, final_storage_class);
 
         const ptr_id = switch (final_storage_class) {
             .Generic => blk: {
@@ -1118,6 +1099,12 @@ const DeclGen = struct {
         return try self.intType(.unsigned, self.getTarget().ptrBitWidth());
     }
 
+    fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !CacheRef {
+        // TODO: This function will be rewritten so that forward declarations work properly
+        const child_ty_ref = try self.resolveType(child_ty, .indirect);
+        return try self.spv.ptrType(child_ty_ref, storage_class);
+    }
+
     /// Generate a union type, optionally with a known field. If the tag alignment is greater
     /// than that of the payload, a regular union (non-packed, with both tag and payload), will
     /// be generated as follows:
@@ -1678,7 +1665,7 @@ const DeclGen = struct {
     /// the name of an error in the text executor.
     fn generateTestEntryPoint(self: *DeclGen, name: []const u8, spv_test_decl_index: SpvModule.Decl.Index) !void {
         const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct);
-        const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup);
+        const ptr_anyerror_ty_ref = try self.ptrType(Type.anyerror, .CrossWorkgroup);
         const void_ty_ref = try self.resolveType(Type.void, .direct);
 
         const kernel_proto_ty_ref = try self.spv.resolve(.{ .function_type = .{
@@ -1713,6 +1700,7 @@ const DeclGen = struct {
             .id_result = error_id,
             .function = test_id,
         });
+        // Note: Convert to direct not required.
         try section.emit(self.spv.gpa, .OpStore, .{
             .pointer = p_error_id,
             .object = error_id,
@@ -1817,8 +1805,7 @@ const DeclGen = struct {
                 else => final_storage_class,
             };
 
-            const ty_ref = try self.resolveType(decl.ty, .indirect);
-            const ptr_ty_ref = try self.spv.ptrType(ty_ref, actual_storage_class);
+            const ptr_ty_ref = try self.ptrType(decl.ty, actual_storage_class);
 
             const begin = self.spv.beginGlobal();
             try self.spv.globals.section.emit(self.spv.gpa, .OpVariable, .{
@@ -2113,9 +2100,7 @@ const DeclGen = struct {
                 constituent.* = try self.convertToIndirect(child_ty, result_id);
             }
 
-            const result_ty = try self.resolveType(child_ty, .indirect);
-            const result_ty_ref = try self.spv.arrayType(vector_len, result_ty);
-            return try self.constructArray(result_ty_ref, constituents);
+            return try self.constructArray(ty, constituents);
         }
 
         const result_id = self.spv.allocId();
@@ -2176,7 +2161,7 @@ const DeclGen = struct {
 
         const info = try self.arithmeticTypeInfo(result_ty);
         // TODO: Use fmin for OpenCL
-        const cmp_id = try self.cmp(op, result_ty, lhs_id, rhs_id);
+        const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
         const selection_id = switch (info.class) {
             .float => blk: {
                 // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
@@ -2311,7 +2296,7 @@ const DeclGen = struct {
                 constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
             }
 
-            return self.constructArray(result_ty_ref, constituents);
+            return self.constructArray(ty, constituents);
         }
 
         // Binary operations are generally applicable to both scalar and vector operations
@@ -2629,6 +2614,7 @@ const DeclGen = struct {
     fn cmp(
         self: *DeclGen,
         op: std.math.CompareOperator,
+        result_ty: Type,
         ty: Type,
         lhs_id: IdRef,
         rhs_id: IdRef,
@@ -2669,7 +2655,7 @@ const DeclGen = struct {
                 if (ty.optionalReprIsPayload(mod)) {
                     assert(payload_ty.hasRuntimeBitsIgnoreComptime(mod));
                     assert(!payload_ty.isSlice(mod));
-                    return self.cmp(op, payload_ty, lhs_id, rhs_id);
+                    return self.cmp(op, Type.bool, payload_ty, lhs_id, rhs_id);
                 }
 
                 const lhs_valid_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
@@ -2682,7 +2668,7 @@ const DeclGen = struct {
                 else
                     try self.convertToDirect(Type.bool, rhs_id);
 
-                const valid_cmp_id = try self.cmp(op, Type.bool, lhs_valid_id, rhs_valid_id);
+                const valid_cmp_id = try self.cmp(op, Type.bool, Type.bool, lhs_valid_id, rhs_valid_id);
                 if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
                     return valid_cmp_id;
                 }
@@ -2693,7 +2679,7 @@ const DeclGen = struct {
                 const lhs_pl_id = try self.extractField(payload_ty, lhs_id, 0);
                 const rhs_pl_id = try self.extractField(payload_ty, rhs_id, 0);
 
-                const pl_cmp_id = try self.cmp(op, payload_ty, lhs_pl_id, rhs_pl_id);
+                const pl_cmp_id = try self.cmp(op, Type.bool, payload_ty, lhs_pl_id, rhs_pl_id);
 
                 // op == .eq  => lhs_valid == rhs_valid && lhs_pl == rhs_pl
                 // op == .neq => lhs_valid != rhs_valid || lhs_pl != rhs_pl
@@ -2715,7 +2701,6 @@ const DeclGen = struct {
             .Vector => {
                 const child_ty = ty.childType(mod);
                 const vector_len = ty.vectorLen(mod);
-                const bool_ty_ref_indirect = try self.resolveType(Type.bool, .indirect);
 
                 var constituents = try self.gpa.alloc(IdRef, vector_len);
                 defer self.gpa.free(constituents);
@@ -2723,12 +2708,11 @@ const DeclGen = struct {
                 for (constituents, 0..) |*constituent, i| {
                     const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
                     const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
-                    const result_id = try self.cmp(op, child_ty, lhs_index_id, rhs_index_id);
+                    const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
                     constituent.* = try self.convertToIndirect(Type.bool, result_id);
                 }
 
-                const result_ty_ref = try self.spv.arrayType(vector_len, bool_ty_ref_indirect);
-                return try self.constructArray(result_ty_ref, constituents);
+                return try self.constructArray(result_ty, constituents);
             },
             else => unreachable,
         };
@@ -2801,8 +2785,9 @@ const DeclGen = struct {
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
         const ty = self.typeOf(bin_op.lhs);
+        const result_ty = self.typeOfIndex(inst);
 
-        return try self.cmp(op, ty, lhs_id, rhs_id);
+        return try self.cmp(op, result_ty, ty, lhs_id, rhs_id);
     }
 
     fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2814,8 +2799,9 @@ const DeclGen = struct {
         const rhs_id = try self.resolve(vec_cmp.rhs);
         const op = vec_cmp.compareOperator();
         const ty = self.typeOf(vec_cmp.lhs);
+        const result_ty = self.typeOfIndex(inst);
 
-        return try self.cmp(op, ty, lhs_id, rhs_id);
+        return try self.cmp(op, result_ty, ty, lhs_id, rhs_id);
     }
 
     fn bitCast(
@@ -2860,15 +2846,9 @@ const DeclGen = struct {
             return result_id;
         }
 
-        const src_ptr_ty_ref = try self.spv.ptrType(src_ty_ref, .Function);
-        const dst_ptr_ty_ref = try self.spv.ptrType(dst_ty_ref, .Function);
+        const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
 
-        const tmp_id = self.spv.allocId();
-        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
-            .id_result_type = self.typeId(src_ptr_ty_ref),
-            .id_result = tmp_id,
-            .storage_class = .Function,
-        });
+        const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
         try self.store(src_ty, tmp_id, src_id, false);
         const casted_ptr_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
@@ -3154,7 +3134,7 @@ const DeclGen = struct {
                     elem_ids[n_elems - 1] = try self.constant(array_info.elem_type, sentinel_val, .indirect);
                 }
 
-                return try self.constructArray(result_ty_ref, elem_ids);
+                return try self.constructArray(result_ty, elem_ids);
             },
             else => unreachable,
         }
@@ -3246,8 +3226,7 @@ const DeclGen = struct {
         const mod = self.module;
         // Construct new pointer type for the resulting pointer
         const elem_ty = ptr_ty.elemType2(mod); // use elemType() so that we get T for *[N]T.
-        const elem_ty_ref = try self.resolveType(elem_ty, .direct);
-        const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, spvStorageClass(ptr_ty.ptrAddressSpace(mod)));
+        const elem_ptr_ty_ref = try self.ptrType(elem_ty, spvStorageClass(ptr_ty.ptrAddressSpace(mod)));
         if (ptr_ty.isSinglePointer(mod)) {
             // Pointer-to-array. In this case, the resulting pointer is not of the same type
             // as the ptr_ty (we want a *T, not a *[N]T), and hence we need to use accessChain.
@@ -3283,9 +3262,7 @@ const DeclGen = struct {
         const mod = self.module;
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
         const array_ty = self.typeOf(bin_op.lhs);
-        const array_ty_ref = try self.resolveType(array_ty, .direct);
         const elem_ty = array_ty.childType(mod);
-        const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
         const array_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
 
@@ -3293,20 +3270,10 @@ const DeclGen = struct {
         // For now, just generate a temporary and use that.
         // TODO: This backend probably also should use isByRef from llvm...
 
-        const array_ptr_ty_ref = try self.spv.ptrType(array_ty_ref, .Function);
-        const elem_ptr_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function);
-
-        const tmp_id = self.spv.allocId();
-        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
-            .id_result_type = self.typeId(array_ptr_ty_ref),
-            .id_result = tmp_id,
-            .storage_class = .Function,
-        });
-        try self.func.body.emit(self.spv.gpa, .OpStore, .{
-            .pointer = tmp_id,
-            .object = array_id,
-        });
+        const elem_ptr_ty_ref = try self.ptrType(elem_ty, .Function);
 
+        const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function });
+        try self.store(array_ty, tmp_id, array_id, false);
         const elem_ptr_id = try self.accessChainId(elem_ptr_ty_ref, tmp_id, &.{index_id});
         return try self.load(elem_ty, elem_ptr_id, false);
     }
@@ -3334,8 +3301,7 @@ const DeclGen = struct {
         if (layout.tag_size == 0) return;
 
         const tag_ty = un_ty.unionTagTypeSafety(mod).?;
-        const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
-        const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod)));
+        const tag_ptr_ty_ref = try self.ptrType(tag_ty, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod)));
 
         const union_ptr_id = try self.resolve(bin_op.lhs);
         const new_tag_id = try self.resolve(bin_op.rhs);
@@ -3400,6 +3366,7 @@ const DeclGen = struct {
             return try self.constInt(tag_ty_ref, tag_int);
         }
 
+        // TODO: Make this use self.ptrType
         const un_active_ty_ref = try self.resolveUnionType(ty, active_field);
         const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
         const un_general_ty_ref = try self.resolveType(ty, .direct);
@@ -3414,23 +3381,16 @@ const DeclGen = struct {
 
         if (layout.tag_size != 0) {
             const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct);
-            const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
+            const tag_ptr_ty_ref = try self.ptrType(maybe_tag_ty.?, .Function);
             const ptr_id = try self.accessChain(tag_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.tag_index))});
             const tag_id = try self.constInt(tag_ty_ref, tag_int);
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = tag_id,
-            });
+            try self.store(maybe_tag_ty.?, ptr_id, tag_id, false);
         }
 
         if (layout.active_field_size != 0) {
-            const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
-            const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
+            const active_field_ptr_ty_ref = try self.ptrType(layout.active_field_ty, .Function);
             const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))});
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = payload.?,
-            });
+            try self.store(layout.active_field_ty, ptr_id, payload.?, false);
         } else {
             assert(payload == null);
         }
@@ -3603,23 +3563,13 @@ const DeclGen = struct {
         return try self.structFieldPtr(result_ptr_ty, struct_ptr_ty, struct_ptr, field_index);
     }
 
-    /// We cannot use an OpVariable directly in an OpSpecConstantOp, but we can
-    /// after we insert a dummy AccessChain...
-    /// TODO: Get rid of this
-    fn makePointerConstant(
-        self: *DeclGen,
-        section: *SpvSection,
-        ptr_ty_ref: CacheRef,
-        ptr_id: IdRef,
-    ) !IdRef {
-        const result_id = self.spv.allocId();
-        try section.emitSpecConstantOp(self.spv.gpa, .OpInBoundsAccessChain, .{
-            .id_result_type = self.typeId(ptr_ty_ref),
-            .id_result = result_id,
-            .base = ptr_id,
-        });
-        return result_id;
-    }
+    const AllocOptions = struct {
+        initializer: ?IdRef = null,
+        /// The final storage class of the pointer. This may be either `.Generic` or `.Function`.
+        /// In either case, the local is allocated in the `.Function` storage class, and optionally
+        /// cast back to `.Generic`.
+        storage_class: StorageClass = .Generic,
+    };
 
     // Allocate a function-local variable, with possible initializer.
     // This function returns a pointer to a variable of type `ty_ref`,
@@ -3627,30 +3577,36 @@ const DeclGen = struct {
     // placed in the Function address space.
     fn alloc(
         self: *DeclGen,
-        ty_ref: CacheRef,
-        initializer: ?IdRef,
+        ty: Type,
+        options: AllocOptions,
     ) !IdRef {
-        const fn_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Function);
-        const general_ptr_ty_ref = try self.spv.ptrType(ty_ref, .Generic);
+        const ptr_fn_ty_ref = try self.ptrType(ty, .Function);
 
         // SPIR-V requires that OpVariable declarations for locals go into the first block, so we are just going to
         // directly generate them into func.prologue instead of the body.
         const var_id = self.spv.allocId();
         try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
-            .id_result_type = self.typeId(fn_ptr_ty_ref),
+            .id_result_type = self.typeId(ptr_fn_ty_ref),
             .id_result = var_id,
             .storage_class = .Function,
-            .initializer = initializer,
+            .initializer = options.initializer,
         });
 
-        // Convert to a generic pointer
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{
-            .id_result_type = self.typeId(general_ptr_ty_ref),
-            .id_result = result_id,
-            .pointer = var_id,
-        });
-        return result_id;
+        switch (options.storage_class) {
+            .Generic => {
+                const ptr_gn_ty_ref = try self.ptrType(ty, .Generic);
+                // Convert to a generic pointer
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{
+                    .id_result_type = self.typeId(ptr_gn_ty_ref),
+                    .id_result = result_id,
+                    .pointer = var_id,
+                });
+                return result_id;
+            },
+            .Function => return var_id,
+            else => unreachable,
+        }
     }
 
     fn airAlloc(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3659,8 +3615,7 @@ const DeclGen = struct {
         const ptr_ty = self.typeOfIndex(inst);
         assert(ptr_ty.ptrAddressSpace(mod) == .generic);
         const child_ty = ptr_ty.childType(mod);
-        const child_ty_ref = try self.resolveType(child_ty, .indirect);
-        return try self.alloc(child_ty_ref, null);
+        return try self.alloc(child_ty, .{});
     }
 
     fn airArg(self: *DeclGen) IdRef {
@@ -4032,7 +3987,7 @@ const DeclGen = struct {
                 .is_null => .eq,
                 .is_non_null => .neq,
             };
-            return try self.cmp(op, ptr_ty, ptr_id, null_id);
+            return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id);
         }
 
         const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))