Commit cbf2ee72e3

Robin Voetter <robin@voetter.nl>
2024-04-06 12:41:58
spirv: fix some recursive pointers edge cases in dedup pass
1 parent 125d332
Changed files (1)
src
link
src/link/SpirV/deduplicate.zig
@@ -47,6 +47,10 @@ const ModuleInfo = struct {
         result_id_index: u16,
         /// The first decoration in `self.decorations`.
         first_decoration: u32,
+
+        fn operands(self: Entity, binary: *const BinaryModule) []const Word {
+            return binary.instructions[self.first_operand..][0..self.num_operands];
+        }
     };
 
     /// Maps result-id to Entity's
@@ -210,10 +214,41 @@ const EntityContext = struct {
 
         const entity = self.info.entities.values()[index];
 
+        // If the current pointer is recursive, don't immediately add it to the map. This is to ensure that
+        // if the current pointer is already recursive, it gets the same hash as a pointer that points to the
+        // same child but has a different result-id.
         if (entity.kind == .OpTypePointer) {
             // This may be either a pointer that is forward-referenced in the future,
             // or a forward reference to a pointer.
-            const entry = try self.ptr_map_a.getOrPut(self.a, id);
+            // Note: We use the **struct** here instead of the pointer itself, to avoid an edge case like this:
+            //
+            // A - C*'
+            //        \
+            //         C - C*'
+            //        /
+            // B - C*"
+            //
+            // In this case, hashing A goes like
+            //   A -> C*' -> C -> C*' recursion
+            // And hashing B goes like
+            //   B -> C*" -> C -> C*' -> C -> C*' recursion
+            // There are several calls to ptrType in codegen that may cause C*' and C*" to be generated as separate
+            // types. This is not a problem for C itself though - this can only be generated through resolveType()
+            // and so ensures equality by Zig's type system. Technically the above problem is still present, but it
+            // would only be present in a structure such as
+            //
+            // A - C*' - C'
+            //             \
+            //              C*" - C - C*
+            //             /
+            //            B
+            //
+            // where there is a duplicate definition of struct C. Resolving this requires a much more time consuming
+            // algorithm though, and because we don't expect any correctness issues with it, we leave that for now.
+
+            // TODO: Do we need to mind the storage class here? It's going to be recursive regardless, right?
+            const struct_id: ResultId = @enumFromInt(entity.operands(self.binary)[2]);
+            const entry = try self.ptr_map_a.getOrPut(self.a, struct_id);
             if (entry.found_existing) {
                 // Pointer already seen. Hash the index instead of recursing into its children.
                 std.hash.autoHash(hasher, entry.index);
@@ -228,12 +263,17 @@ const EntityContext = struct {
         for (decorations) |decoration| {
             try self.hashEntity(hasher, decoration);
         }
+
+        if (entity.kind == .OpTypePointer) {
+            const struct_id: ResultId = @enumFromInt(entity.operands(self.binary)[2]);
+            assert(self.ptr_map_a.swapRemove(struct_id));
+        }
     }
 
     fn hashEntity(self: *EntityContext, hasher: *std.hash.Wyhash, entity: ModuleInfo.Entity) !void {
         std.hash.autoHash(hasher, entity.kind);
         // Process operands
-        const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands];
+        const operands = entity.operands(self.binary);
         for (operands, 0..) |operand, i| {
             if (i == entity.result_id_index) {
                 // Not relevant, skip...
@@ -273,12 +313,19 @@ const EntityContext = struct {
         const entity_a = self.info.entities.values()[index_a];
         const entity_b = self.info.entities.values()[index_b];
 
+        if (entity_a.kind != entity_b.kind) {
+            return false;
+        }
+
         if (entity_a.kind == .OpTypePointer) {
             // May be a forward reference, or should be saved as a potential
             // forward reference in the future. Whatever the case, it should
             // be the same for both a and b.
-            const entry_a = try self.ptr_map_a.getOrPut(self.a, id_a);
-            const entry_b = try self.ptr_map_b.getOrPut(self.a, id_b);
+            const struct_id_a: ResultId = @enumFromInt(entity_a.operands(self.binary)[2]);
+            const struct_id_b: ResultId = @enumFromInt(entity_b.operands(self.binary)[2]);
+
+            const entry_a = try self.ptr_map_a.getOrPut(self.a, struct_id_a);
+            const entry_b = try self.ptr_map_b.getOrPut(self.a, struct_id_b);
 
             if (entry_a.found_existing != entry_b.found_existing) return false;
             if (entry_a.index != entry_b.index) return false;
@@ -306,6 +353,14 @@ const EntityContext = struct {
             }
         }
 
+        if (entity_a.kind == .OpTypePointer) {
+            const struct_id_a: ResultId = @enumFromInt(entity_a.operands(self.binary)[2]);
+            const struct_id_b: ResultId = @enumFromInt(entity_b.operands(self.binary)[2]);
+
+            assert(self.ptr_map_a.swapRemove(struct_id_a));
+            assert(self.ptr_map_b.swapRemove(struct_id_b));
+        }
+
         return true;
     }
 
@@ -316,8 +371,8 @@ const EntityContext = struct {
             return false;
         }
 
-        const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands];
-        const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands];
+        const operands_a = entity_a.operands(self.binary);
+        const operands_b = entity_b.operands(self.binary);
 
         // Note: returns false for operands that have explicit defaults in optional operands... oh well
         if (operands_a.len != operands_b.len) {
@@ -463,7 +518,7 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule, progress: *std.P
                 if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) {
                     // Grab the pointer's storage class from its operands in the original
                     // module.
-                    const storage_class: spec.StorageClass = @enumFromInt(binary.instructions[entity.first_operand + 1]);
+                    const storage_class: spec.StorageClass = @enumFromInt(entity.operands(binary)[1]);
                     try section.emit(a, .OpTypeForwardPointer, .{
                         .pointer_type = id,
                         .storage_class = storage_class,