Commit 3d5721da23

Robin Voetter <robin@voetter.nl>
2024-03-09 12:00:34
spirv: update spec generator
For module parsing and assembling, we will also need to know all of the SPIR-V extensions and their instructions. This commit updates the generator to generate those. Because there are multiple instruction sets that each have a separate list of Opcodes, no separate enum is generated for these opcodes. Additionally, the previous mechanism for runtime instruction information, `Opcode`'s `fn operands()`, has been removed in favor for `InstructionSet.core.instructions()`. Any mapping from operand to instruction is to be done at runtime. Using a runtime populated hashmap should also be more efficient than the previous mechanism using `stringToEnum`.
1 parent 3bffa58
Changed files (2)
tools/spirv/grammar.zig
@@ -22,8 +22,8 @@ pub const CoreRegistry = struct {
 };
 
 pub const ExtensionRegistry = struct {
-    copyright: [][]const u8,
-    version: u32,
+    copyright: ?[][]const u8 = null,
+    version: ?u32 = null,
     revision: u32,
     instructions: []Instruction,
     operand_kinds: []OperandKind = &[_]OperandKind{},
@@ -40,6 +40,8 @@ pub const Instruction = struct {
     opcode: u32,
     operands: []Operand = &[_]Operand{},
     capabilities: [][]const u8 = &[_][]const u8{},
+    // DebugModuleINTEL has this...
+    capability: ?[]const u8 = null,
     extensions: [][]const u8 = &[_][]const u8{},
     version: ?[]const u8 = null,
 
tools/gen_spirv_spec.zig
@@ -1,45 +1,110 @@
 const std = @import("std");
-const g = @import("spirv/grammar.zig");
 const Allocator = std.mem.Allocator;
+const g = @import("spirv/grammar.zig");
+const CoreRegistry = g.CoreRegistry;
+const ExtensionRegistry = g.ExtensionRegistry;
+const Instruction = g.Instruction;
+const OperandKind = g.OperandKind;
+const Enumerant = g.Enumerant;
+const Operand = g.Operand;
 
 const ExtendedStructSet = std.StringHashMap(void);
 
+const Extension = struct {
+    name: []const u8,
+    spec: ExtensionRegistry,
+};
+
+const CmpInst = struct {
+    fn lt(_: CmpInst, a: Instruction, b: Instruction) bool {
+        return a.opcode < b.opcode;
+    }
+};
+
+const StringPair = struct { []const u8, []const u8 };
+
+const StringPairContext = struct {
+    pub fn hash(_: @This(), a: StringPair) u32 {
+        var hasher = std.hash.Wyhash.init(0);
+        const x, const y = a;
+        hasher.update(x);
+        hasher.update(y);
+        return @truncate(hasher.final());
+    }
+
+    pub fn eql(_: @This(), a: StringPair, b: StringPair, b_index: usize) bool {
+        _ = b_index;
+        const a_x, const a_y = a;
+        const b_x, const b_y = b;
+        return std.mem.eql(u8, a_x, b_x) and std.mem.eql(u8, a_y, b_y);
+    }
+};
+
+const OperandKindMap = std.ArrayHashMap(StringPair, OperandKind, StringPairContext, true);
+
 pub fn main() !void {
     var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
     defer arena.deinit();
-    const allocator = arena.allocator();
+    const a = arena.allocator();
 
-    const args = try std.process.argsAlloc(allocator);
+    const args = try std.process.argsAlloc(a);
     if (args.len != 2) {
-        usageAndExit(std.io.getStdErr(), args[0], 1);
+        usageAndExit(args[0], 1);
+    }
+
+    const json_path = try std.fs.path.join(a, &.{ args[1], "include/spirv/unified1/" });
+    const dir = try std.fs.cwd().openDir(json_path, .{ .iterate = true });
+
+    // const spec_path = try std.fs.path.join(a, &.{spirv_headers_dir_path, "spirv.core.grammar.json"});
+    // const core_spec = try std.fs.cwd().readFileAlloc(a, spec_path, std.math.maxInt(usize));
+
+    const core_spec = try readRegistry(CoreRegistry, a, dir, "spirv.core.grammar.json");
+    std.sort.block(Instruction, core_spec.instructions, CmpInst{}, CmpInst.lt);
+
+    var exts = std.ArrayList(Extension).init(a);
+
+    var it = dir.iterate();
+    while (try it.next()) |entry| {
+        if (entry.kind != .file or !std.mem.startsWith(u8, entry.name, "extinst.")) {
+            continue;
+        }
+
+        std.debug.assert(std.mem.endsWith(u8, entry.name, ".grammar.json"));
+        const name = entry.name["extinst.".len .. entry.name.len - ".grammar.json".len];
+        const spec = try readRegistry(ExtensionRegistry, a, dir, entry.name);
+
+        std.sort.block(Instruction, spec.instructions, CmpInst{}, CmpInst.lt);
+
+        try exts.append(.{ .name = try a.dupe(u8, name), .spec = spec });
     }
 
-    const spec_path = args[1];
-    const spec = try std.fs.cwd().readFileAlloc(allocator, spec_path, std.math.maxInt(usize));
+    var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
+    try render(bw.writer(), a, core_spec, exts.items);
+    try bw.flush();
+}
 
+fn readRegistry(comptime RegistryType: type, a: Allocator, dir: std.fs.Dir, path: []const u8) !RegistryType {
+    const spec = try dir.readFileAlloc(a, path, std.math.maxInt(usize));
     // Required for json parsing.
     @setEvalBranchQuota(10000);
 
-    var scanner = std.json.Scanner.initCompleteInput(allocator, spec);
+    var scanner = std.json.Scanner.initCompleteInput(a, spec);
     var diagnostics = std.json.Diagnostics{};
     scanner.enableDiagnostics(&diagnostics);
-    const parsed = std.json.parseFromTokenSource(g.CoreRegistry, allocator, &scanner, .{}) catch |err| {
-        std.debug.print("line,col: {},{}\n", .{ diagnostics.getLine(), diagnostics.getColumn() });
+    const parsed = std.json.parseFromTokenSource(RegistryType, a, &scanner, .{}) catch |err| {
+        std.debug.print("{s}:{}:{}:\n", .{ path, diagnostics.getLine(), diagnostics.getColumn() });
         return err;
     };
-
-    var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
-    try render(bw.writer(), allocator, parsed.value);
-    try bw.flush();
+    return parsed.value;
 }
 
 /// Returns a set with types that require an extra struct for the `Instruction` interface
 /// to the spir-v spec, or whether the original type can be used.
 fn extendedStructs(
-    arena: Allocator,
-    kinds: []const g.OperandKind,
+    a: Allocator,
+    kinds: []const OperandKind,
 ) !ExtendedStructSet {
-    var map = ExtendedStructSet.init(arena);
+    var map = ExtendedStructSet.init(a);
     try map.ensureTotalCapacity(@as(u32, @intCast(kinds.len)));
 
     for (kinds) |kind| {
@@ -73,7 +138,7 @@ fn tagPriorityScore(tag: []const u8) usize {
     }
 }
 
-fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void {
+fn render(writer: anytype, a: Allocator, registry: CoreRegistry, extensions: []const Extension) !void {
     try writer.writeAll(
         \\//! This file is auto-generated by tools/gen_spirv_spec.zig.
         \\
@@ -99,6 +164,7 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
         \\pub const IdScope = IdRef;
         \\
         \\pub const LiteralInteger = Word;
+        \\pub const LiteralFloat = Word;
         \\pub const LiteralString = []const u8;
         \\pub const LiteralContextDependentNumber = union(enum) {
         \\    int32: i32,
@@ -139,6 +205,12 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
         \\    parameters: []const OperandKind,
         \\};
         \\
+        \\pub const Instruction = struct {
+        \\    name: []const u8,
+        \\    opcode: Word,
+        \\    operands: []const Operand,
+        \\};
+        \\
         \\
     );
 
@@ -151,15 +223,123 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
         .{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number },
     );
 
-    const extended_structs = try extendedStructs(allocator, registry.operand_kinds);
-    try renderClass(writer, allocator, registry.instructions);
-    try renderOperandKind(writer, registry.operand_kinds);
-    try renderOpcodes(writer, allocator, registry.instructions, extended_structs);
-    try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs);
+    // Merge the operand kinds from all extensions together.
+    // var all_operand_kinds = std.ArrayList(OperandKind).init(a);
+    // try all_operand_kinds.appendSlice(registry.operand_kinds);
+    var all_operand_kinds = OperandKindMap.init(a);
+    for (registry.operand_kinds) |kind| {
+        try all_operand_kinds.putNoClobber(.{ "core", kind.kind }, kind);
+    }
+    for (extensions) |ext| {
+        // Note: extensions may define the same operand kind, with different
+        // parameters. Instead of trying to merge them, just discriminate them
+        // using the name of the extension. This is similar to what
+        // the official headers do.
+
+        try all_operand_kinds.ensureUnusedCapacity(ext.spec.operand_kinds.len);
+        for (ext.spec.operand_kinds) |kind| {
+            var new_kind = kind;
+            new_kind.kind = try std.mem.join(a, ".", &.{ ext.name, kind.kind });
+            try all_operand_kinds.putNoClobber(.{ ext.name, kind.kind }, new_kind);
+        }
+    }
+
+    const extended_structs = try extendedStructs(a, all_operand_kinds.values());
+    // Note: extensions don't seem to have class.
+    try renderClass(writer, a, registry.instructions);
+    try renderOperandKind(writer, all_operand_kinds.values());
+    try renderOpcodes(writer, a, registry.instructions, extended_structs);
+    try renderOperandKinds(writer, a, all_operand_kinds.values(), extended_structs);
+    try renderInstructionSet(writer, a, registry, extensions, all_operand_kinds);
+}
+
+fn renderInstructionSet(
+    writer: anytype,
+    a: Allocator,
+    core: CoreRegistry,
+    extensions: []const Extension,
+    all_operand_kinds: OperandKindMap,
+) !void {
+    _ = a;
+    try writer.writeAll(
+        \\pub const InstructionSet = enum {
+        \\    core,
+    );
+
+    for (extensions) |ext| {
+        try writer.print("{},\n", .{std.zig.fmtId(ext.name)});
+    }
+
+    try writer.writeAll(
+        \\
+        \\    pub fn instructions(self: InstructionSet) []const Instruction {
+        \\        return switch (self) {
+        \\
+    );
+
+    try renderInstructionsCase(writer, "core", core.instructions, all_operand_kinds);
+    for (extensions) |ext| {
+        try renderInstructionsCase(writer, ext.name, ext.spec.instructions, all_operand_kinds);
+    }
+
+    try writer.writeAll(
+        \\        };
+        \\    }
+        \\};
+        \\
+    );
+}
+
+fn renderInstructionsCase(
+    writer: anytype,
+    set_name: []const u8,
+    instructions: []const Instruction,
+    all_operand_kinds: OperandKindMap,
+) !void {
+    // Note: theoretically we could dedup from tags and give every instruction a list of aliases,
+    // but there aren't so many total aliases and that would add more overhead in total. We will
+    // just filter those out when needed.
+
+    try writer.print(".{} => &[_]Instruction{{\n", .{std.zig.fmtId(set_name)});
+
+    for (instructions) |inst| {
+        try writer.print(
+            \\.{{
+            \\    .name = "{s}",
+            \\    .opcode = {},
+            \\    .operands = &[_]Operand{{
+            \\
+        , .{ inst.opname, inst.opcode });
+
+        for (inst.operands) |operand| {
+            const quantifier = if (operand.quantifier) |q|
+                switch (q) {
+                    .@"?" => "optional",
+                    .@"*" => "variadic",
+                }
+            else
+                "required";
+
+            const kind = all_operand_kinds.get(.{ set_name, operand.kind }) orelse
+                all_operand_kinds.get(.{ "core", operand.kind }).?;
+            try writer.print(".{{.kind = .{}, .quantifier = .{s}}},\n", .{ std.zig.fmtId(kind.kind), quantifier });
+        }
+
+        try writer.writeAll(
+            \\    },
+            \\},
+            \\
+        );
+    }
+
+    try writer.writeAll(
+        \\},
+        \\
+    );
 }
 
-fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.Instruction) !void {
-    var class_map = std.StringArrayHashMap(void).init(allocator);
+fn renderClass(writer: anytype, a: Allocator, instructions: []const Instruction) !void {
+    var class_map = std.StringArrayHashMap(void).init(a);
 
     for (instructions) |inst| {
         if (std.mem.eql(u8, inst.class.?, "@exclude")) {
@@ -173,7 +353,7 @@ fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.In
         try renderInstructionClass(writer, class);
         try writer.writeAll(",\n");
     }
-    try writer.writeAll("};\n");
+    try writer.writeAll("};\n\n");
 }
 
 fn renderInstructionClass(writer: anytype, class: []const u8) !void {
@@ -192,7 +372,7 @@ fn renderInstructionClass(writer: anytype, class: []const u8) !void {
     }
 }
 
-fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
+fn renderOperandKind(writer: anytype, operands: []const OperandKind) !void {
     try writer.writeAll("pub const OperandKind = enum {\n");
     for (operands) |operand| {
         try writer.print("{},\n", .{std.zig.fmtId(operand.kind)});
@@ -242,7 +422,7 @@ fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
     try writer.writeAll("};\n}\n};\n");
 }
 
-fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void {
+fn renderEnumerant(writer: anytype, enumerant: Enumerant) !void {
     try writer.print(".{{.name = \"{s}\", .value = ", .{enumerant.enumerant});
     switch (enumerant.value) {
         .bitflag => |flag| try writer.writeAll(flag),
@@ -260,14 +440,14 @@ fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void {
 
 fn renderOpcodes(
     writer: anytype,
-    allocator: Allocator,
-    instructions: []const g.Instruction,
+    a: Allocator,
+    instructions: []const Instruction,
     extended_structs: ExtendedStructSet,
 ) !void {
-    var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator);
+    var inst_map = std.AutoArrayHashMap(u32, usize).init(a);
     try inst_map.ensureTotalCapacity(instructions.len);
 
-    var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator);
+    var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(a);
     try aliases.ensureTotalCapacity(instructions.len);
 
     for (instructions, 0..) |inst, i| {
@@ -323,31 +503,6 @@ fn renderOpcodes(
         try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs);
     }
 
-    try writer.writeAll(
-        \\};
-        \\}
-        \\pub fn operands(self: Opcode) []const Operand {
-        \\return switch (self) {
-        \\
-    );
-
-    for (instructions_indices) |i| {
-        const inst = instructions[i];
-        try writer.print(".{} => &[_]Operand{{", .{std.zig.fmtId(inst.opname)});
-        for (inst.operands) |operand| {
-            const quantifier = if (operand.quantifier) |q|
-                switch (q) {
-                    .@"?" => "optional",
-                    .@"*" => "variadic",
-                }
-            else
-                "required";
-
-            try writer.print(".{{.kind = .{s}, .quantifier = .{s}}},", .{ operand.kind, quantifier });
-        }
-        try writer.writeAll("},\n");
-    }
-
     try writer.writeAll(
         \\};
         \\}
@@ -368,14 +523,14 @@ fn renderOpcodes(
 
 fn renderOperandKinds(
     writer: anytype,
-    allocator: Allocator,
-    kinds: []const g.OperandKind,
+    a: Allocator,
+    kinds: []const OperandKind,
     extended_structs: ExtendedStructSet,
 ) !void {
     for (kinds) |kind| {
         switch (kind.category) {
-            .ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs),
-            .BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs),
+            .ValueEnum => try renderValueEnum(writer, a, kind, extended_structs),
+            .BitEnum => try renderBitEnum(writer, a, kind, extended_structs),
             else => {},
         }
     }
@@ -383,20 +538,26 @@ fn renderOperandKinds(
 
 fn renderValueEnum(
     writer: anytype,
-    allocator: Allocator,
-    enumeration: g.OperandKind,
+    a: Allocator,
+    enumeration: OperandKind,
     extended_structs: ExtendedStructSet,
 ) !void {
     const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
 
-    var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator);
+    var enum_map = std.AutoArrayHashMap(u32, usize).init(a);
     try enum_map.ensureTotalCapacity(enumerants.len);
 
-    var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator);
+    var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(a);
     try aliases.ensureTotalCapacity(enumerants.len);
 
     for (enumerants, 0..) |enumerant, i| {
-        const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int);
+        try writer.context.flush();
+        const value: u31 = switch (enumerant.value) {
+            .int => |value| value,
+            // Some extensions declare ints as string
+            .bitflag => |value| try std.fmt.parseInt(u31, value, 10),
+        };
+        const result = enum_map.getOrPutAssumeCapacity(value);
         if (!result.found_existing) {
             result.value_ptr.* = i;
             continue;
@@ -422,9 +583,12 @@ fn renderValueEnum(
 
     for (enum_indices) |i| {
         const enumerant = enumerants[i];
-        if (enumerant.value != .int) return error.InvalidRegistry;
+        // if (enumerant.value != .int) return error.InvalidRegistry;
 
-        try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int });
+        switch (enumerant.value) {
+            .int => |value| try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
+            .bitflag => |value| try writer.print("{} = {s},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
+        }
     }
 
     try writer.writeByte('\n');
@@ -454,8 +618,8 @@ fn renderValueEnum(
 
 fn renderBitEnum(
     writer: anytype,
-    allocator: Allocator,
-    enumeration: g.OperandKind,
+    a: Allocator,
+    enumeration: OperandKind,
     extended_structs: ExtendedStructSet,
 ) !void {
     try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)});
@@ -463,7 +627,7 @@ fn renderBitEnum(
     var flags_by_bitpos = [_]?usize{null} ** 32;
     const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
 
-    var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator);
+    var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(a);
     try aliases.ensureTotalCapacity(enumerants.len);
 
     for (enumerants, 0..) |enumerant, i| {
@@ -471,6 +635,10 @@ fn renderBitEnum(
         const value = try parseHexInt(enumerant.value.bitflag);
         if (value == 0) {
             continue; // Skip 'none' items
+        } else if (std.mem.eql(u8, enumerant.enumerant, "FlagIsPublic")) {
+            // This flag is special and poorly defined in the json files.
+            // Just skip it for now
+            continue;
         }
 
         std.debug.assert(@popCount(value) == 1);
@@ -540,7 +708,7 @@ fn renderOperand(
         mask,
     },
     field_name: []const u8,
-    parameters: []const g.Operand,
+    parameters: []const Operand,
     extended_structs: ExtendedStructSet,
 ) !void {
     if (kind == .instruction) {
@@ -606,7 +774,7 @@ fn renderOperand(
     try writer.writeAll(",\n");
 }
 
-fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void {
+fn renderFieldName(writer: anytype, operands: []const Operand, field_index: usize) !void {
     const operand = operands[field_index];
 
     // Should be enough for all names - adjust as needed.
@@ -673,16 +841,16 @@ fn parseHexInt(text: []const u8) !u31 {
     return try std.fmt.parseInt(u31, text[prefix.len..], 16);
 }
 
-fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn {
-    file.writer().print(
-        \\Usage: {s} <spirv json spec>
+fn usageAndExit(arg0: []const u8, code: u8) noreturn {
+    std.io.getStdErr().writer().print(
+        \\Usage: {s} <SPIRV-Headers repository path>
         \\
-        \\Generates Zig bindings for a SPIR-V specification .json (either core or
-        \\extinst versions). The result, printed to stdout, should be used to update
+        \\Generates Zig bindings for SPIR-V specifications found in the SPIRV-Headers
+        \\repository. The result, printed to stdout, should be used to update
         \\files in src/codegen/spirv. Don't forget to format the output.
         \\
-        \\The relevant specifications can be obtained from the SPIR-V registry:
-        \\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/
+        \\<SPIRV-Headers repository path> should point to a clone of
+        \\https://github.com/KhronosGroup/SPIRV-Headers/
         \\
     , .{arg0}) catch std.process.exit(1);
     std.process.exit(code);