Commit ced8a2c3a6

Robin Voetter <robin@voetter.nl>
2023-08-26 12:19:28
spirv: add type_map to map AIR types to SPIR-V types
This will help us both to make the implementation a little more efficient by caching emission for certain types like structs, and also allow us to attach extra information about types that we can use while lowering without performing a search over the entire type tree for some property.
1 parent 79f7481
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -12,6 +12,7 @@ const LazySrcLoc = Module.LazySrcLoc;
 const Air = @import("../Air.zig");
 const Zir = @import("../Zir.zig");
 const Liveness = @import("../Liveness.zig");
+const InternPool = @import("../InternPool.zig");
 
 const spec = @import("spirv/spec.zig");
 const Opcode = spec.Opcode;
@@ -30,6 +31,15 @@ const SpvAssembler = @import("spirv/Assembler.zig");
 
 const InstMap = std.AutoHashMapUnmanaged(Air.Inst.Index, IdRef);
 
+/// We want to store some extra facts about types as mapped from Zig to SPIR-V.
+/// This structure is used to keep that extra information, as well as
+/// the cached reference to the type.
+const SpvTypeInfo = struct {
+    ty_ref: CacheRef,
+};
+
+const TypeMap = std.AutoHashMapUnmanaged(InternPool.Index, SpvTypeInfo);
+
 const IncomingBlock = struct {
     src_label_id: IdRef,
     break_value_id: IdRef,
@@ -78,6 +88,15 @@ pub const DeclGen = struct {
     /// A map keeping track of which instruction generated which result-id.
     inst_results: InstMap = .{},
 
+    /// A map that maps AIR intern pool indices to SPIR-V cache references (which
+    /// is basically the same thing except for SPIR-V).
+    /// This map is typically only used for structures that are deemed heavy enough
+    /// that it is worth to store them here. The SPIR-V module also interns types,
+    /// and so the main purpose of this map is to avoid recomputation and to
+    /// cache extra information about the type rather than to aid in validity
+    /// of the SPIR-V module.
+    type_map: TypeMap = .{},
+
     /// We need to keep track of result ids for block labels, as well as the 'incoming'
     /// blocks for a block.
     blocks: BlockMap = .{},
@@ -207,6 +226,7 @@ pub const DeclGen = struct {
     pub fn deinit(self: *DeclGen) void {
         self.args.deinit(self.gpa);
         self.inst_results.deinit(self.gpa);
+        self.type_map.deinit(self.gpa);
         self.blocks.deinit(self.gpa);
         self.func.deinit(self.gpa);
     }
@@ -1180,6 +1200,9 @@ pub const DeclGen = struct {
             return try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect);
         }
 
+        const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+        if (entry.found_existing) return entry.value_ptr.ty_ref;
+
         var member_types = std.BoundedArray(CacheRef, 4){};
         var member_names = std.BoundedArray(CacheString, 4){};
 
@@ -1222,10 +1245,16 @@ pub const DeclGen = struct {
             member_names.appendAssumeCapacity(try self.spv.resolveString("padding"));
         }
 
-        return try self.spv.resolve(.{ .struct_type = .{
+        const ty_ref = try self.spv.resolve(.{ .struct_type = .{
             .member_types = member_types.slice(),
             .member_names = member_names.slice(),
         } });
+
+        entry.value_ptr.* = .{
+            .ty_ref = ty_ref,
+        };
+
+        return ty_ref;
     }
 
     /// Turn a Zig type into a SPIR-V Type, and return a reference to it.
@@ -1268,15 +1297,26 @@ pub const DeclGen = struct {
                 return try self.spv.resolve(.{ .float_type = .{ .bits = bits } });
             },
             .Array => {
+                const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+                if (entry.found_existing) return entry.value_ptr.ty_ref;
+
                 const elem_ty = ty.childType(mod);
-                const elem_ty_ref = try self.resolveType(elem_ty, .direct);
+                const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
                 const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel(mod)) orelse {
                     return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel(mod)});
                 };
-                return self.spv.arrayType(total_len, elem_ty_ref);
+                const ty_ref = try self.spv.arrayType(total_len, elem_ty_ref);
+                entry.value_ptr.* = .{
+                    .ty_ref = ty_ref,
+                };
+                return ty_ref;
             },
             .Fn => switch (repr) {
                 .direct => {
+                    const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+                    if (entry.found_existing) return entry.value_ptr.ty_ref;
+
+                    const ip = &mod.intern_pool;
                     const fn_info = mod.typeToFunc(ty).?;
                     // TODO: Put this somewhere in Sema.zig
                     if (fn_info.is_var_args)
@@ -1289,10 +1329,16 @@ pub const DeclGen = struct {
                     }
                     const return_ty_ref = try self.resolveType(fn_info.return_type.toType(), .direct);
 
-                    return try self.spv.resolve(.{ .function_type = .{
+                    const ty_ref = try self.spv.resolve(.{ .function_type = .{
                         .return_type = return_ty_ref,
                         .parameters = param_ty_refs,
                     } });
+
+                    entry.value_ptr.* = .{
+                        .ty_ref = ty_ref,
+                    };
+
+                    return ty_ref;
                 },
                 .indirect => {
                     // TODO: Represent function pointers properly.
@@ -1338,6 +1384,9 @@ pub const DeclGen = struct {
                 } });
             },
             .Struct => {
+                const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+                if (entry.found_existing) return entry.value_ptr.ty_ref;
+
                 const struct_type = switch (ip.indexToKey(ty.toIntern())) {
                     .anon_struct_type => |tuple| {
                         const member_types = try self.gpa.alloc(CacheRef, tuple.values.len);
@@ -1351,9 +1400,14 @@ pub const DeclGen = struct {
                             member_index += 1;
                         }
 
-                        return try self.spv.resolve(.{ .struct_type = .{
+                        const ty_ref = try self.spv.resolve(.{ .struct_type = .{
                             .member_types = member_types[0..member_index],
                         } });
+
+                        entry.value_ptr.* = .{
+                            .ty_ref = ty_ref,
+                        };
+                        return ty_ref;
                     },
                     .struct_type => |struct_type| struct_type,
                     else => unreachable,
@@ -1361,7 +1415,6 @@ pub const DeclGen = struct {
 
                 if (struct_type.layout == .Packed) {
                     return try self.resolveType(struct_type.backingIntType(ip).toType(), .direct);
-                }
 
                 var member_types = std.ArrayList(CacheRef).init(self.gpa);
                 defer member_types.deinit();
@@ -1379,11 +1432,16 @@ pub const DeclGen = struct {
 
                 const name = ip.stringToSlice(try mod.declPtr(struct_type.decl.unwrap().?).getFullyQualifiedName(mod));
 
-                return try self.spv.resolve(.{ .struct_type = .{
+                const ty_ref = try self.spv.resolve(.{ .struct_type = .{
                     .name = try self.spv.resolveString(name),
                     .member_types = member_types.items,
                     .member_names = member_names.items,
                 } });
+
+                entry.value_ptr.* = .{
+                    .ty_ref = ty_ref,
+                };
+                return ty_ref;
             },
             .Optional => {
                 const payload_ty = ty.optionalChild(mod);
@@ -1400,15 +1458,23 @@ pub const DeclGen = struct {
                     return payload_ty_ref;
                 }
 
+                const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+                if (entry.found_existing) return entry.value_ptr.ty_ref;
+
                 const bool_ty_ref = try self.resolveType(Type.bool, .indirect);
 
-                return try self.spv.resolve(.{ .struct_type = .{
+                const ty_ref = try self.spv.resolve(.{ .struct_type = .{
                     .member_types = &.{ payload_ty_ref, bool_ty_ref },
                     .member_names = &.{
                         try self.spv.resolveString("payload"),
                         try self.spv.resolveString("valid"),
                     },
                 } });
+
+                entry.value_ptr.* = .{
+                    .ty_ref = ty_ref,
+                };
+                return ty_ref;
             },
             .Union => return try self.resolveUnionType(ty, null),
             .ErrorSet => return try self.intType(.unsigned, 16),
@@ -1421,6 +1487,9 @@ pub const DeclGen = struct {
                     return error_ty_ref;
                 }
 
+                const entry = try self.type_map.getOrPut(self.gpa, ty.toIntern());
+                if (entry.found_existing) return entry.value_ptr.ty_ref;
+
                 const payload_ty_ref = try self.resolveType(payload_ty, .indirect);
 
                 var member_types: [2]CacheRef = undefined;
@@ -1443,10 +1512,15 @@ pub const DeclGen = struct {
                     // TODO: ABI padding?
                 }
 
-                return try self.spv.resolve(.{ .struct_type = .{
+                const ty_ref = try self.spv.resolve(.{ .struct_type = .{
                     .member_types = &member_types,
                     .member_names = &member_names,
                 } });
+
+                entry.value_ptr.* = .{
+                    .ty_ref = ty_ref,
+                };
+                return ty_ref;
             },
 
             .Null,