Commit f5ab3c93c9

Robin Voetter <robin@voetter.nl>
2024-03-30 10:06:55
spirv: handle annotations in deduplication pass
1 parent b496039
Changed files (1)
src
link
src/link/SpirV/deduplicate.zig
@@ -17,7 +17,8 @@ fn canDeduplicate(opcode: Opcode) bool {
             // These are deprecated, so don't bother supporting them for now.
             return false;
         },
-        .OpName, .OpMemberName => true, // Debug decoration-style instructions
+        // Debug decoration-style instructions
+        .OpName, .OpMemberName => true,
         else => switch (opcode.class()) {
             .TypeDeclaration,
             .ConstantCreation,
@@ -44,6 +45,8 @@ const ModuleInfo = struct {
         /// or the entity that is affected by this entity if this entity
         /// is a decoration.
         result_id_index: u16,
+        /// The first decoration in `self.decorations`.
+        first_decoration: u32,
     };
 
     /// Maps result-id to Entity's
@@ -53,6 +56,8 @@ const ModuleInfo = struct {
     /// Because we need these values when recoding the module anyway,
     /// it contains the status of ALL operands in the module.
     operand_is_id: std.DynamicBitSetUnmanaged,
+    /// Store of decorations for each entity.
+    decorations: []const Entity,
 
     pub fn parse(
         arena: Allocator,
@@ -62,6 +67,7 @@ const ModuleInfo = struct {
         var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena);
         var id_offsets = std.ArrayList(u16).init(arena);
         var operand_is_id = try std.DynamicBitSetUnmanaged.initEmpty(arena, binary.instructions.len);
+        var decorations = std.MultiArrayList(struct { target_id: ResultId, entity: Entity }){};
 
         var it = binary.iterateInstructions();
         while (it.next()) |inst| {
@@ -82,10 +88,20 @@ const ModuleInfo = struct {
             };
 
             const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]);
+            const entity = Entity{
+                .kind = inst.opcode,
+                .first_operand = first_operand_offset,
+                .num_operands = @intCast(inst.operands.len),
+                .result_id_index = result_id_index,
+                .first_decoration = undefined, // Filled in later
+            };
 
             switch (inst.opcode.class()) {
                 .Annotation, .Debug => {
-                    // TODO
+                    try decorations.append(arena, .{
+                        .target_id = result_id,
+                        .entity = entity,
+                    });
                 },
                 .TypeDeclaration, .ConstantCreation => {
                     const entry = try entities.getOrPut(result_id);
@@ -93,22 +109,67 @@ const ModuleInfo = struct {
                         log.err("type or constant {} has duplicate definition", .{result_id});
                         return error.DuplicateId;
                     }
-                    entry.value_ptr.* = .{
-                        .kind = inst.opcode,
-                        .first_operand = first_operand_offset,
-                        .num_operands = @intCast(inst.operands.len),
-                        .result_id_index = result_id_index,
-                    };
+                    entry.value_ptr.* = entity;
                 },
                 else => unreachable,
             }
         }
 
+        // Sort decorations by the index of the result-id in `entities`.
+        // This ensures not only that the decorations of a particular result-id
+        // are contiguous, but the subsequences also appear in the same order as in `entities`.
+
+        // Comparison context handed to `decorations.sort`: orders decoration
+        // indices by the position of their target result-id in `entities`.
+        const SortContext = struct {
+            entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity),
+            ids: []const ResultId,
+
+            pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
+                // A target that is missing from `entities` is not a deduplicatable
+                // result-id. Such a target never compares less than anything, and
+                // every known target compares less than it, so it floats to the end.
+                const lhs = ctx.entities.getIndex(ctx.ids[a_index]) orelse return false;
+                return if (ctx.entities.getIndex(ctx.ids[b_index])) |rhs| lhs < rhs else true;
+            }
+        };
+
+        decorations.sort(SortContext{
+            .entities = entities.unmanaged,
+            .ids = decorations.items(.target_id),
+        });
+
+        // Now go through the decorations and add the offsets to the entities list.
+        var decoration_i: u32 = 0;
+        const target_ids = decorations.items(.target_id);
+        for (entities.keys(), entities.values()) |id, *entity| {
+            entity.first_decoration = decoration_i;
+
+            // Scan ahead to the next decoration
+            while (decoration_i < target_ids.len and target_ids[decoration_i] == id) {
+                decoration_i += 1;
+            }
+        }
+
         return ModuleInfo{
             .entities = entities.unmanaged,
             .operand_is_id = operand_is_id,
+            // There may be unrelated decorations at the end, so make sure to
+            // slice those off.
+            .decorations = decorations.items(.entity)[0..decoration_i],
         };
     }
+
+    /// Returns the run of `self.decorations` that belongs to the entity stored
+    /// at position `index` of `self.entities`.
+    fn entityDecorationsByIndex(self: ModuleInfo, index: usize) []const Entity {
+        const values = self.entities.values();
+        const start = values[index].first_decoration;
+        // A run ends where the next entity's run begins; the last entity's run
+        // extends to the end of the decoration store.
+        const end = if (index + 1 < values.len)
+            values[index + 1].first_decoration
+        else
+            self.decorations.len;
+        return self.decorations[start..end];
+    }
 };
 
 const EntityContext = struct {
@@ -138,23 +199,39 @@ const EntityContext = struct {
         return hasher.final();
     }
 
-    fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) !void {
-        const index = self.info.entities.getIndex(id).?;
+    fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) error{OutOfMemory}!void {
+        const index = self.info.entities.getIndex(id) orelse {
+            // Index unknown, the type or constant may depend on another result-id
+            // that couldn't be deduplicated and so it wasn't added to info.entities.
+            // In this case, just hash the ID itself.
+            std.hash.autoHash(hasher, id);
+            return;
+        };
+
         const entity = self.info.entities.values()[index];
 
-        std.hash.autoHash(hasher, entity.kind);
         if (entity.kind == .OpTypePointer) {
             // This may be either a pointer that is forward-referenced in the future,
             // or a forward reference to a pointer.
             const entry = try self.ptr_map_a.getOrPut(self.a, id);
             if (entry.found_existing) {
                 // Pointer already seen. Hash the index instead of recursing into its children.
-                // TODO: Discriminate this path somehow?
                 std.hash.autoHash(hasher, entry.index);
                 return;
             }
         }
 
+        try self.hashEntity(hasher, entity);
+
+        // Process decorations: fold each of the entity's decorations into the same hash.
+        const decorations = self.info.entityDecorationsByIndex(index);
+        for (decorations) |decoration| {
+            try self.hashEntity(hasher, decoration);
+        }
+    }
+
+    fn hashEntity(self: *EntityContext, hasher: *std.hash.Wyhash, entity: ModuleInfo.Entity) !void {
+        std.hash.autoHash(hasher, entity.kind);
         // Process operands
         const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands];
         for (operands, 0..) |operand, i| {
@@ -178,19 +255,24 @@ const EntityContext = struct {
         return try self.eqlInner(a, b);
     }
 
-    fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) !bool {
-        const index_a = self.info.entities.getIndex(id_a).?;
-        const index_b = self.info.entities.getIndex(id_b).?;
+    fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) error{OutOfMemory}!bool {
+        const maybe_index_a = self.info.entities.getIndex(id_a);
+        const maybe_index_b = self.info.entities.getIndex(id_b);
+
+        if (maybe_index_a == null and maybe_index_b == null) {
+            // Both indices unknown. In this case the type or constant
+            // may depend on another result-id that couldn't be deduplicated
+            // (so it wasn't added to info.entities). In this case, that particular
+            // result-id should be the same one.
+            return id_a == id_b;
+        }
+
+        const index_a = maybe_index_a orelse return false;
+        const index_b = maybe_index_b orelse return false;
 
         const entity_a = self.info.entities.values()[index_a];
         const entity_b = self.info.entities.values()[index_b];
 
-        if (entity_a.kind != entity_b.kind) {
-            return false;
-        } else if (entity_a.result_id_index != entity_a.result_id_index) {
-            return false;
-        }
-
         if (entity_a.kind == .OpTypePointer) {
             // May be a forward reference, or should be saved as a potential
             // forward reference in the future. Whatever the case, it should
@@ -207,6 +289,33 @@ const EntityContext = struct {
             }
         }
 
+        if (!try self.eqlEntities(entity_a, entity_b)) {
+            return false;
+        }
+
+        // Compare decorations.
+        const decorations_a = self.info.entityDecorationsByIndex(index_a);
+        const decorations_b = self.info.entityDecorationsByIndex(index_b);
+        if (decorations_a.len != decorations_b.len) {
+            return false;
+        }
+
+        for (decorations_a, decorations_b) |decoration_a, decoration_b| {
+            if (!try self.eqlEntities(decoration_a, decoration_b)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    /// Compares `entity_a` against `entity_b`. `eqlInner` calls this both for
+    /// the entities themselves and for each pair of their decorations.
+    fn eqlEntities(self: *EntityContext, entity_a: ModuleInfo.Entity, entity_b: ModuleInfo.Entity) !bool {
+        // Both the opcode and the index of the result-id (or decoration target)
+        // operand must agree. The second check must compare a against b:
+        // comparing a field with itself can never reject anything.
+        if (entity_a.kind != entity_b.kind) {
+            return false;
+        } else if (entity_a.result_id_index != entity_b.result_id_index) {
+            return false;
+        }
+
         const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands];
         const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands];
 
@@ -260,7 +369,6 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
     const a = arena.allocator();
 
     const info = try ModuleInfo.parse(a, parser, binary.*);
-    log.info("added {} entities", .{info.entities.count()});
 
     // Hash all keys once so that the maps can be allocated the right size.
     var ctx = EntityContext{
@@ -280,10 +388,9 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
         .entity_context = &ctx,
     });
     var replace = std.AutoArrayHashMap(ResultId, ResultId).init(a);
-    for (info.entities.keys(), info.entities.values()) |id, entity| {
+    for (info.entities.keys()) |id| {
         const entry = try map.getOrPut(id);
         if (entry.found_existing) {
-            log.info("deduplicating {} - {s} (prior definition: {})", .{ id, @tagName(entity.kind), entry.key_ptr.* });
             try replace.putNoClobber(id, entry.key_ptr.*);
         }
     }
@@ -297,13 +404,15 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
     while (it.next()) |inst| {
         // Result-id can only be the first or second operand
         const inst_spec = parser.getInstSpec(inst.opcode).?;
-        const maybe_result_id: ?ResultId = for (0..2) |i| {
+
+        const maybe_result_id_offset: ?u16 = for (0..2) |i| {
             if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) {
-                break @enumFromInt(inst.operands[i]);
+                break @intCast(i);
             }
         } else null;
 
-        if (maybe_result_id) |result_id| {
+        if (maybe_result_id_offset) |offset| {
+            const result_id: ResultId = @enumFromInt(inst.operands[offset]);
             if (replace.contains(result_id)) continue;
         }
 
@@ -312,8 +421,16 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
                 new_functions_section = section.instructions.items.len;
             },
             .OpTypeForwardPointer => continue, // We re-emit these where needed
-            // TODO: These aren't supported yet, strip them out for testing purposes.
-            .OpName, .OpMemberName => continue,
+            else => {},
+        }
+
+        switch (inst.opcode.class()) {
+            .Annotation, .Debug => {
+                // For decoration-style instructions, only emit them
+                // if the target is not removed.
+                const target: ResultId = @enumFromInt(inst.operands[0]);
+                if (replace.contains(target)) continue;
+            },
             else => {},
         }
 
@@ -330,9 +447,8 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
                 operand.* = @intFromEnum(new_id);
             }
 
-            const id: ResultId = @enumFromInt(operand.*);
-            // TODO: This test is a little janky. Check the offset instead?
-            if (maybe_result_id == null or maybe_result_id.? != id) {
+            if (maybe_result_id_offset == null or maybe_result_id_offset.? != i) {
+                const id: ResultId = @enumFromInt(operand.*);
                 const index = info.entities.getIndex(id) orelse continue;
                 const entity = info.entities.values()[index];
                 if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) {
@@ -349,7 +465,8 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
         }
 
         if (inst.opcode == .OpTypePointer) {
-            try emitted_ptrs.put(maybe_result_id.?, {});
+            const result_id: ResultId = @enumFromInt(new_operands.items[maybe_result_id_offset.?]);
+            try emitted_ptrs.put(result_id, {});
         }
 
         try section.emitRawInstruction(a, inst.opcode, new_operands.items);