Commit 393a805741

Robin Voetter <robin@voetter.nl>
2024-03-29 13:54:28
spirv: deduplicate prototype
1 parent aff71c6
Changed files (3)
src/link/SpirV/BinaryModule.zig
@@ -94,6 +94,8 @@ pub const ParseError = error{
     DuplicateId,
     /// Some ID did not resolve.
     InvalidId,
+    /// This opcode or instruction is not supported yet.
+    UnsupportedOperation,
     /// Parser ran out of memory.
     OutOfMemory,
 };
src/link/SpirV/deduplicate.zig
@@ -0,0 +1,400 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+const log = std.log.scoped(.spirv_link);
+const assert = std.debug.assert;
+
+const BinaryModule = @import("BinaryModule.zig");
+const Section = @import("../../codegen/spirv/Section.zig");
+const spec = @import("../../codegen/spirv/spec.zig");
+const Opcode = spec.Opcode;
+const ResultId = spec.IdResult;
+const Word = spec.Word;
+
+fn canDeduplicate(opcode: Opcode) bool {
+    return switch (opcode) {
+        .OpTypeForwardPointer => false, // Don't need to handle these
+        .OpGroupDecorate, .OpGroupMemberDecorate => {
+            // These are deprecated, so don't bother supporting them for now.
+            return false;
+        },
+        .OpName, .OpMemberName => true, // Debug decoration-style instructions
+        else => switch (opcode.class()) {
+            .TypeDeclaration,
+            .ConstantCreation,
+            .Annotation,
+            => true,
+            else => false,
+        },
+    };
+}
+
+const ModuleInfo = struct {
+    /// This models a type, decoration or constant instruction
+    /// and its dependencies.
+    const Entity = struct {
+        /// The type that this entity represents. This is just
+        /// the instruction opcode.
+        kind: Opcode,
+        /// Offset of first child result-id, stored in entity_children.
+        /// These are the shallow entities appearing directly in the
+        /// type's instruction.
+        first_child: u32,
+        /// Offset to the first word of extra-data: Data in the instruction
+        /// that must be considered for uniqueness, but doesn't include
+        /// any IDs.
+        first_extra_data: u32,
+    };
+
+    /// Maps result-id to Entity's
+    entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity),
+    /// The list of children per instruction.
+    entity_children: []const ResultId,
+    /// The list of extra data per instruction.
+    /// TODO: This is a bit awkward, maybe we need to store it some
+    /// other way?
+    extra_data: []const u32,
+
+    pub fn parse(
+        arena: Allocator,
+        parser: *BinaryModule.Parser,
+        binary: BinaryModule,
+    ) !ModuleInfo {
+        var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena);
+        var entity_children = std.ArrayList(ResultId).init(arena);
+        var extra_data = std.ArrayList(u32).init(arena);
+        var id_offsets = std.ArrayList(u16).init(arena);
+
+        var it = binary.iterateInstructions();
+        while (it.next()) |inst| {
+            if (inst.opcode == .OpFunction) break; // No more declarations are possible
+            if (!canDeduplicate(inst.opcode)) continue;
+
+            id_offsets.items.len = 0;
+            try parser.parseInstructionResultIds(binary, inst, &id_offsets);
+
+            const result_id_index: u32 = switch (inst.opcode.class()) {
+                .TypeDeclaration, .Annotation, .Debug => 0,
+                .ConstantCreation => 1,
+                else => unreachable,
+            };
+
+            const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]);
+
+            const first_child: u32 = @intCast(entity_children.items.len);
+            const first_extra_data: u32 = @intCast(extra_data.items.len);
+
+            try entity_children.ensureUnusedCapacity(id_offsets.items.len - 1);
+            try extra_data.ensureUnusedCapacity(inst.operands.len - id_offsets.items.len);
+
+            var id_i: usize = 0;
+            for (inst.operands, 0..) |operand, i| {
+                assert(id_i == id_offsets.items.len or id_offsets.items[id_i] >= i);
+                if (id_i != id_offsets.items.len and id_offsets.items[id_i] == i) {
+                    // Skip .IdResult / .IdResultType.
+                    if (id_i != result_id_index) {
+                        entity_children.appendAssumeCapacity(@enumFromInt(operand));
+                    }
+                    id_i += 1;
+                } else {
+                    // Non-id operand, add it to extra data.
+                    extra_data.appendAssumeCapacity(operand);
+                }
+            }
+
+            switch (inst.opcode.class()) {
+                .Annotation, .Debug => {
+                    // TODO
+                },
+                .TypeDeclaration, .ConstantCreation => {
+                    const entry = try entities.getOrPut(result_id);
+                    if (entry.found_existing) {
+                        log.err("type or constant {} has duplicate definition", .{result_id});
+                        return error.DuplicateId;
+                    }
+                    entry.value_ptr.* = .{
+                        .kind = inst.opcode,
+                        .first_child = first_child,
+                        .first_extra_data = first_extra_data,
+                    };
+                },
+                else => unreachable,
+            }
+        }
+
+        return ModuleInfo{
+            .entities = entities.unmanaged,
+            .entity_children = entity_children.items,
+            .extra_data = extra_data.items,
+        };
+    }
+
+    /// Fetch a slice of children for the index corresponding to an entity.
+    fn childrenByIndex(self: ModuleInfo, index: usize) []const ResultId {
+        const values = self.entities.values();
+        const first_child = values[index].first_child;
+        if (index == values.len - 1) {
+            return self.entity_children[first_child..];
+        } else {
+            const next_first_child = values[index + 1].first_child;
+            return self.entity_children[first_child..next_first_child];
+        }
+    }
+
+    /// Fetch the slice of extra-data for the index corresponding to an entity.
+    fn extraDataByIndex(self: ModuleInfo, index: usize) []const u32 {
+        const values = self.entities.values();
+        const first_extra_data = values[index].first_extra_data;
+        if (index == values.len - 1) {
+            return self.extra_data[first_extra_data..];
+        } else {
+            const next_extra_data = values[index + 1].first_extra_data;
+            return self.extra_data[first_extra_data..next_extra_data];
+        }
+    }
+};
+
+const EntityContext = struct {
+    a: Allocator,
+    ptr_map_a: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{},
+    ptr_map_b: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{},
+    info: *const ModuleInfo,
+
+    fn init(a: Allocator, info: *const ModuleInfo) EntityContext {
+        return .{
+            .a = a,
+            .info = info,
+        };
+    }
+
+    fn deinit(self: *EntityContext) void {
+        self.ptr_map_a.deinit(self.a);
+        self.ptr_map_b.deinit(self.a);
+
+        self.* = undefined;
+    }
+
+    fn equalizeMapCapacity(self: *EntityContext) !void {
+        const cap = @max(self.ptr_map_a.capacity(), self.ptr_map_b.capacity());
+        try self.ptr_map_a.ensureTotalCapacity(self.a, cap);
+        try self.ptr_map_b.ensureTotalCapacity(self.a, cap);
+    }
+
+    fn hash(self: *EntityContext, id: ResultId) !u64 {
+        var hasher = std.hash.Wyhash.init(0);
+        self.ptr_map_a.clearRetainingCapacity();
+        try self.hashInner(&hasher, id);
+        return hasher.final();
+    }
+
+    fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) !void {
+        const index = self.info.entities.getIndex(id).?;
+        const entity = self.info.entities.values()[index];
+
+        std.hash.autoHash(hasher, entity.kind);
+        if (entity.kind == .OpTypePointer) {
+            // This may be either a pointer that is forward-referenced in the future,
+            // or a forward reference to a pointer.
+            const entry = try self.ptr_map_a.getOrPut(self.a, id);
+            if (entry.found_existing) {
+                // Pointer already seen. Hash the index instead of recursing into its children.
+                // TODO: Discriminate this path somehow?
+                std.hash.autoHash(hasher, entry.index);
+                return;
+            }
+        }
+
+        // Hash extra data
+        for (self.info.extraDataByIndex(index)) |data| {
+            std.hash.autoHash(hasher, data);
+        }
+
+        // Hash children
+        for (self.info.childrenByIndex(index)) |child| {
+            try self.hashInner(hasher, child);
+        }
+    }
+
+    fn eql(self: *EntityContext, a: ResultId, b: ResultId) !bool {
+        self.ptr_map_a.clearRetainingCapacity();
+        self.ptr_map_b.clearRetainingCapacity();
+
+        return try self.eqlInner(a, b);
+    }
+
+    fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) !bool {
+        const index_a = self.info.entities.getIndex(id_a).?;
+        const index_b = self.info.entities.getIndex(id_b).?;
+
+        const entity_a = self.info.entities.values()[index_a];
+        const entity_b = self.info.entities.values()[index_b];
+
+        if (entity_a.kind != entity_b.kind) return false;
+
+        if (entity_a.kind == .OpTypePointer) {
+            // May be a forward reference, or should be saved as a potential
+            // forward reference in the future. Whatever the case, it should
+            // be the same for both a and b.
+            const entry_a = try self.ptr_map_a.getOrPut(self.a, id_a);
+            const entry_b = try self.ptr_map_b.getOrPut(self.a, id_b);
+
+            if (entry_a.found_existing != entry_b.found_existing) return false;
+            if (entry_a.index != entry_b.index) return false;
+
+            if (entry_a.found_existing) {
+                // No need to recurse.
+                return true;
+            }
+        }
+
+        // Check if extra data is the same.
+        if (!std.mem.eql(u32, self.info.extraDataByIndex(index_a), self.info.extraDataByIndex(index_b))) {
+            return false;
+        }
+
+        // Recursively check if children are the same
+        const children_a = self.info.childrenByIndex(index_a);
+        const children_b = self.info.childrenByIndex(index_b);
+        if (children_a.len != children_b.len) return false;
+
+        for (children_a, children_b) |child_a, child_b| {
+            if (!try self.eqlInner(child_a, child_b)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+};
+
+/// This struct is a wrapper around EntityContext that adapts it for
+/// use in a hash map. Because EntityContext allocates, it cannot be
+/// used. This wrapper simply assumes that the maps have been allocated
+/// the max amount of memory they are going to use.
+/// This is done by pre-hashing all keys.
+const EntityHashContext = struct {
+    entity_context: *EntityContext,
+
+    pub fn hash(self: EntityHashContext, key: ResultId) u64 {
+        return self.entity_context.hash(key) catch unreachable;
+    }
+
+    pub fn eql(self: EntityHashContext, a: ResultId, b: ResultId) bool {
+        return self.entity_context.eql(a, b) catch unreachable;
+    }
+};
+
+pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void {
+    var arena = std.heap.ArenaAllocator.init(parser.a);
+    defer arena.deinit();
+    const a = arena.allocator();
+
+    const info = try ModuleInfo.parse(a, parser, binary.*);
+    log.info("added {} entities", .{info.entities.count()});
+    log.info("children size: {}", .{info.entity_children.len});
+    log.info("extra data size: {}", .{info.extra_data.len});
+
+    // Hash all keys once so that the maps can be allocated the right size.
+    var ctx = EntityContext.init(a, &info);
+    for (info.entities.keys()) |id| {
+        _ = try ctx.hash(id);
+    }
+
+    // hash only uses ptr_map_a, so allocate ptr_map_b too
+    try ctx.equalizeMapCapacity();
+
+    // Figure out which entities can be deduplicated.
+    var map = std.HashMap(ResultId, void, EntityHashContext, 80).initContext(a, .{
+        .entity_context = &ctx,
+    });
+    var replace = std.AutoArrayHashMap(ResultId, ResultId).init(a);
+    for (info.entities.keys(), info.entities.values()) |id, entity| {
+        const entry = try map.getOrPut(id);
+        if (entry.found_existing) {
+            log.info("deduplicating {} - {s} (prior definition: {})", .{ id, @tagName(entity.kind), entry.key_ptr.* });
+            try replace.putNoClobber(id, entry.key_ptr.*);
+        }
+    }
+
+    // Now process the module, and replace instructions where needed.
+    var section = Section{};
+    var it = binary.iterateInstructions();
+    var id_offsets = std.ArrayList(u16).init(a);
+    var new_functions_section: ?usize = null;
+    var new_operands = std.ArrayList(u32).init(a);
+    var emitted_ptrs = std.AutoHashMap(ResultId, void).init(a);
+    while (it.next()) |inst| {
+        // Result-id can only be the first or second operand
+        const inst_spec = parser.getInstSpec(inst.opcode).?;
+        const maybe_result_id: ?ResultId = for (0..2) |i| {
+            if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) {
+                break @enumFromInt(inst.operands[i]);
+            }
+        } else null;
+
+        if (maybe_result_id) |result_id| {
+            if (replace.contains(result_id)) continue;
+        }
+
+        switch (inst.opcode) {
+            .OpFunction => if (new_functions_section == null) {
+                new_functions_section = section.instructions.items.len;
+            },
+            .OpTypeForwardPointer => continue, // We re-emit these where needed
+            // TODO: These aren't supported yet, strip them out for testing purposes.
+            .OpName, .OpMemberName => continue,
+            else => {},
+        }
+
+        // Re-emit the instruction, but replace all the IDs.
+
+        id_offsets.items.len = 0;
+        try parser.parseInstructionResultIds(binary.*, inst, &id_offsets);
+
+        new_operands.items.len = 0;
+        try new_operands.appendSlice(inst.operands);
+        for (id_offsets.items) |offset| {
+            {
+                const id: ResultId = @enumFromInt(inst.operands[offset]);
+                if (replace.get(id)) |new_id| {
+                    new_operands.items[offset] = @intFromEnum(new_id);
+                }
+            }
+
+            // TODO: Does this logic work? Maybe it will emit an OpTypeForwardPointer to
+            // something thats not a struct...
+            // It seems to work correctly on behavior.zig at least
+            const id: ResultId = @enumFromInt(new_operands.items[offset]);
+            if (maybe_result_id == null or maybe_result_id.? != id) {
+                const index = info.entities.getIndex(id) orelse continue;
+                const entity = info.entities.values()[index];
+                if (entity.kind == .OpTypePointer) {
+                    if (!emitted_ptrs.contains(id)) {
+                        // The storage class is in the extra data
+                        // TODO: This is kind of hacky...
+                        const extra_data = info.extraDataByIndex(index);
+                        const storage_class: spec.StorageClass = @enumFromInt(extra_data[0]);
+                        try section.emit(a, .OpTypeForwardPointer, .{
+                            .pointer_type = id,
+                            .storage_class = storage_class,
+                        });
+                        try emitted_ptrs.put(id, {});
+                    }
+                }
+            }
+        }
+
+        if (inst.opcode == .OpTypePointer) {
+            try emitted_ptrs.put(maybe_result_id.?, {});
+        }
+
+        try section.emitRawInstruction(a, inst.opcode, new_operands.items);
+    }
+
+    for (replace.keys()) |key| {
+        _ = binary.ext_inst_map.remove(key);
+        _ = binary.arith_type_width.remove(key);
+    }
+
+    binary.instructions = try parser.a.dupe(Word, section.toWords());
+    binary.sections.functions = new_functions_section orelse binary.instructions.len;
+}
src/link/SpirV.zig
@@ -261,6 +261,7 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word {
 
     const lower_invocation_globals = @import("SpirV/lower_invocation_globals.zig");
     const prune_unused = @import("SpirV/prune_unused.zig");
+    const dedup = @import("SpirV/deduplicate.zig");
 
     var parser = try BinaryModule.Parser.init(a);
     defer parser.deinit();
@@ -268,6 +269,7 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word {
 
     try lower_invocation_globals.run(&parser, &binary);
     try prune_unused.run(&parser, &binary);
+    try dedup.run(&parser, &binary);
 
     return binary.finalize(a);
 }