Commit ff042e8006

Robin Voetter <robin@voetter.nl>
2022-01-21 15:05:02
spirv: improve generator
The spirv spec generator now also generates some support information: Opcode gains a function to query a Zig type representing the operands of the opcode. The idea is that this will enable a richer interface for emitting spirv instructions.
1 parent 0e6d218
Changed files (1)
tools/gen_spirv_spec.zig
@@ -1,5 +1,8 @@
 const std = @import("std");
 const g = @import("spirv/grammar.zig");
+const Allocator = std.mem.Allocator;
+
+const ExtendedStructSet = std.StringHashMap(void);
 
 pub fn main() !void {
     var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
@@ -20,101 +23,308 @@ pub fn main() !void {
     var tokens = std.json.TokenStream.init(spec);
     var registry = try std.json.parse(g.Registry, &tokens, .{ .allocator = allocator });
 
+    const core_reg = switch (registry) {
+        .core => |core_reg| core_reg,
+        .extension => return error.TODOSpirVExtensionSpec,
+    };
+
     var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
-    try render(bw.writer(), registry);
+    try render(bw.writer(), allocator, core_reg);
     try bw.flush();
 }
 
-fn render(writer: anytype, registry: g.Registry) !void {
+/// Returns a set with types that require an extra struct for the `Instruction` interface
+/// to the spir-v spec, or whether the original type can be used.
+fn extendedStructs(
+    arena: Allocator,
+    kinds: []const g.OperandKind,
+) !ExtendedStructSet {
+    var map = ExtendedStructSet.init(arena);
+    try map.ensureTotalCapacity(@intCast(u32, kinds.len));
+
+    for (kinds) |kind| {
+        const enumerants = kind.enumerants orelse continue;
+
+        for (enumerants) |enumerant| {
+            if (enumerant.parameters.len > 0) {
+                break;
+            }
+        } else continue;
+
+        map.putAssumeCapacity(kind.kind, {});
+    }
+
+    return map;
+}
+
+// Return a score for a particular priority. Duplicate instruction/operand enum values are
+// removed by picking the tag with the lowest score to keep, and by making an alias for the
+// other. Note that the tag does not need to be just a tag at this point, in which case it
+// gets the lowest score automatically anyway.
+fn tagPriorityScore(tag: []const u8) usize {
+    if (tag.len == 0) {
+        return 1;
+    } else if (std.mem.eql(u8, tag, "EXT")) {
+        return 2;
+    } else if (std.mem.eql(u8, tag, "KHR")) {
+        return 3;
+    } else {
+        return 4;
+    }
+}
+
+fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void {
     try writer.writeAll(
         \\//! This file is auto-generated by tools/gen_spirv_spec.zig.
         \\
         \\const Version = @import("std").builtin.Version;
         \\
+        \\pub const Word = u32;
+        \\pub const IdResultType = struct{
+        \\    id: Word,
+        \\    pub fn toRef(self: IdResultType) IdRef {
+        \\        return .{.id = self.id};
+        \\    }
+        \\};
+        \\pub const IdResult = struct{
+        \\    id: Word,
+        \\    pub fn toRef(self: IdResult) IdRef {
+        \\        return .{.id = self.id};
+        \\    }
+        \\    pub fn toResultType(self: IdResult) IdResultType {
+        \\        return .{.id = self.id};
+        \\    }
+        \\};
+        \\pub const IdRef = struct{ id: Word };
+        \\
+        \\pub const IdMemorySemantics = IdRef;
+        \\pub const IdScope = IdRef;
+        \\
+        \\pub const LiteralInteger = Word;
+        \\pub const LiteralString = []const u8;
+        \\pub const LiteralContextDependentNumber = union(enum) {
+        \\    int32: i32,
+        \\    uint32: u32,
+        \\    int64: i64,
+        \\    uint64: u64,
+        \\    float32: f32,
+        \\    float64: f64,
+        \\};
+        \\pub const LiteralExtInstInteger = struct{ inst: Word };
+        \\pub const LiteralSpecConstantOpInteger = struct { opcode: Opcode };
+        \\pub const PairLiteralIntegerIdRef = struct { value: LiteralInteger, label: IdRef };
+        \\pub const PairIdRefLiteralInteger = struct { target: IdRef, member: LiteralInteger };
+        \\pub const PairIdRefIdRef = [2]IdRef;
+        \\
+        \\
     );
 
-    switch (registry) {
-        .core => |core_reg| {
-            try writer.print(
-                \\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }};
-                \\pub const magic_number: u32 = {s};
-                \\
-            ,
-                .{ core_reg.major_version, core_reg.minor_version, core_reg.revision, core_reg.magic_number },
-            );
-            try renderOpcodes(writer, core_reg.instructions);
-            try renderOperandKinds(writer, core_reg.operand_kinds);
-        },
-        .extension => |ext_reg| {
-            try writer.print(
-                \\pub const version = Version{{ .major = {}, .minor = 0, .patch = {} }};
-                \\
-            ,
-                .{ ext_reg.version, ext_reg.revision },
-            );
-            try renderOpcodes(writer, ext_reg.instructions);
-            try renderOperandKinds(writer, ext_reg.operand_kinds);
-        },
-    }
+    try writer.print(
+        \\pub const version = Version{{ .major = {}, .minor = {}, .patch = {} }};
+        \\pub const magic_number: Word = {s};
+        \\
+    ,
+        .{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number },
+    );
+    const extended_structs = try extendedStructs(allocator, registry.operand_kinds);
+    try renderOpcodes(writer, allocator, registry.instructions, extended_structs);
+    try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs);
 }
 
-fn renderOpcodes(writer: anytype, instructions: []const g.Instruction) !void {
-    try writer.writeAll("pub const Opcode = extern enum(u16) {\n");
-    for (instructions) |instr| {
-        try writer.print("    {} = {},\n", .{ std.zig.fmtId(instr.opname), instr.opcode });
+fn renderOpcodes(
+    writer: anytype,
+    allocator: Allocator,
+    instructions: []const g.Instruction,
+    extended_structs: ExtendedStructSet,
+) !void {
+    var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator);
+    try inst_map.ensureTotalCapacity(instructions.len);
+
+    var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator);
+    try aliases.ensureTotalCapacity(instructions.len);
+
+    for (instructions) |inst, i| {
+        const result = inst_map.getOrPutAssumeCapacity(inst.opcode);
+        if (!result.found_existing) {
+            result.value_ptr.* = i;
+            continue;
+        }
+
+        const existing = instructions[result.value_ptr.*];
+
+        const tag_index = std.mem.indexOfDiff(u8, inst.opname, existing.opname).?;
+        const inst_priority = tagPriorityScore(inst.opname[tag_index..]);
+        const existing_priority = tagPriorityScore(existing.opname[tag_index..]);
+
+        if (inst_priority < existing_priority) {
+            aliases.appendAssumeCapacity(.{ .inst = result.value_ptr.*, .alias = i });
+            result.value_ptr.* = i;
+        } else {
+            aliases.appendAssumeCapacity(.{ .inst = i, .alias = result.value_ptr.* });
+        }
     }
-    try writer.writeAll("    _,\n};\n");
+
+    const instructions_indices = inst_map.values();
+
+    try writer.writeAll("pub const Opcode = enum(u16) {\n");
+    for (instructions_indices) |i| {
+        const inst = instructions[i];
+        try writer.print("{} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode });
+    }
+
+    try writer.writeByte('\n');
+
+    for (aliases.items) |alias| {
+        try writer.print("pub const {} = Opcode.{};\n", .{
+            std.zig.fmtId(instructions[alias.inst].opname),
+            std.zig.fmtId(instructions[alias.alias].opname),
+        });
+    }
+
+    try writer.writeAll(
+        \\
+        \\pub fn Operands(comptime self: Opcode) type {
+        \\return switch (self) {
+        \\
+    );
+
+    for (instructions_indices) |i| {
+        const inst = instructions[i];
+        try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs);
+    }
+    try writer.writeAll("};\n}\n};\n");
+    _ = extended_structs;
 }
 
-fn renderOperandKinds(writer: anytype, kinds: []const g.OperandKind) !void {
+fn renderOperandKinds(
+    writer: anytype,
+    allocator: Allocator,
+    kinds: []const g.OperandKind,
+    extended_structs: ExtendedStructSet,
+) !void {
     for (kinds) |kind| {
         switch (kind.category) {
-            .ValueEnum => try renderValueEnum(writer, kind),
-            .BitEnum => try renderBitEnum(writer, kind),
+            .ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs),
+            .BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs),
             else => {},
         }
     }
 }
 
-fn renderValueEnum(writer: anytype, enumeration: g.OperandKind) !void {
-    try writer.print("pub const {s} = extern enum(u32) {{\n", .{enumeration.kind});
-
+fn renderValueEnum(
+    writer: anytype,
+    allocator: Allocator,
+    enumeration: g.OperandKind,
+    extended_structs: ExtendedStructSet,
+) !void {
     const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
-    for (enumerants) |enumerant| {
+
+    var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator);
+    try enum_map.ensureTotalCapacity(enumerants.len);
+
+    var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator);
+    try aliases.ensureTotalCapacity(enumerants.len);
+
+    for (enumerants) |enumerant, i| {
+        const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int);
+        if (!result.found_existing) {
+            result.value_ptr.* = i;
+            continue;
+        }
+
+        const existing = enumerants[result.value_ptr.*];
+
+        const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, existing.enumerant).?;
+        const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]);
+        const existing_priority = tagPriorityScore(existing.enumerant[tag_index..]);
+
+        if (enum_priority < existing_priority) {
+            aliases.appendAssumeCapacity(.{ .enumerant = result.value_ptr.*, .alias = i });
+            result.value_ptr.* = i;
+        } else {
+            aliases.appendAssumeCapacity(.{ .enumerant = i, .alias = result.value_ptr.* });
+        }
+    }
+
+    const enum_indices = enum_map.values();
+
+    try writer.print("pub const {s} = enum(u32) {{\n", .{std.zig.fmtId(enumeration.kind)});
+
+    for (enum_indices) |i| {
+        const enumerant = enumerants[i];
         if (enumerant.value != .int) return error.InvalidRegistry;
 
-        try writer.print("    {} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int });
+        try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int });
+    }
+
+    try writer.writeByte('\n');
+
+    for (aliases.items) |alias| {
+        try writer.print("pub const {} = {}.{};\n", .{
+            std.zig.fmtId(enumerants[alias.enumerant].enumerant),
+            std.zig.fmtId(enumeration.kind),
+            std.zig.fmtId(enumerants[alias.alias].enumerant),
+        });
+    }
+
+    if (!extended_structs.contains(enumeration.kind)) {
+        try writer.writeAll("};\n");
+        return;
+    }
+
+    try writer.print("\npub const Extended = union({}) {{\n", .{std.zig.fmtId(enumeration.kind)});
+
+    for (enum_indices) |i| {
+        const enumerant = enumerants[i];
+        try renderOperand(writer, .@"union", enumerant.enumerant, enumerant.parameters, extended_structs);
     }
 
-    try writer.writeAll("    _,\n};\n");
+    try writer.writeAll("};\n};\n");
 }
 
-fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void {
-    try writer.print("pub const {s} = packed struct {{\n", .{enumeration.kind});
+fn renderBitEnum(
+    writer: anytype,
+    allocator: Allocator,
+    enumeration: g.OperandKind,
+    extended_structs: ExtendedStructSet,
+) !void {
+    try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)});
 
-    var flags_by_bitpos = [_]?[]const u8{null} ** 32;
+    var flags_by_bitpos = [_]?usize{null} ** 32;
     const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
-    for (enumerants) |enumerant| {
+
+    var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator);
+    try aliases.ensureTotalCapacity(enumerants.len);
+
+    for (enumerants) |enumerant, i| {
         if (enumerant.value != .bitflag) return error.InvalidRegistry;
         const value = try parseHexInt(enumerant.value.bitflag);
-        if (@popCount(u32, value) != 1) {
-            continue; // Skip combinations and 'none' items
+        if (@popCount(u32, value) == 0) {
+            continue; // Skip 'none' items
         }
 
+        std.debug.assert(@popCount(u32, value) == 1);
+
         var bitpos = std.math.log2_int(u32, value);
         if (flags_by_bitpos[bitpos]) |*existing| {
-            // Keep the shortest
-            if (enumerant.enumerant.len < existing.len)
-                existing.* = enumerant.enumerant;
+            const tag_index = std.mem.indexOfDiff(u8, enumerant.enumerant, enumerants[existing.*].enumerant).?;
+            const enum_priority = tagPriorityScore(enumerant.enumerant[tag_index..]);
+            const existing_priority = tagPriorityScore(enumerants[existing.*].enumerant[tag_index..]);
+
+            if (enum_priority < existing_priority) {
+                aliases.appendAssumeCapacity(.{ .flag = existing.*, .alias = bitpos });
+                existing.* = i;
+            } else {
+                aliases.appendAssumeCapacity(.{ .flag = i, .alias = bitpos });
+            }
         } else {
-            flags_by_bitpos[bitpos] = enumerant.enumerant;
+            flags_by_bitpos[bitpos] = i;
         }
     }
 
-    for (flags_by_bitpos) |maybe_flag_name, bitpos| {
-        try writer.writeAll("    ");
-        if (maybe_flag_name) |flag_name| {
-            try writer.writeAll(flag_name);
+    for (flags_by_bitpos) |maybe_flag_index, bitpos| {
+        if (maybe_flag_index) |flag_index| {
+            try writer.print("{}", .{std.zig.fmtId(enumerants[flag_index].enumerant)});
         } else {
             try writer.print("_reserved_bit_{}", .{bitpos});
         }
@@ -126,7 +336,169 @@ fn renderBitEnum(writer: anytype, enumeration: g.OperandKind) !void {
         try writer.writeAll("= false,\n");
     }
 
-    try writer.writeAll("};\n");
+    try writer.writeByte('\n');
+
+    for (aliases.items) |alias| {
+        try writer.print("pub const {}: {} = .{{.{} = true}};\n", .{
+            std.zig.fmtId(enumerants[alias.flag].enumerant),
+            std.zig.fmtId(enumeration.kind),
+            std.zig.fmtId(enumerants[flags_by_bitpos[alias.alias].?].enumerant),
+        });
+    }
+
+    if (!extended_structs.contains(enumeration.kind)) {
+        try writer.writeAll("};\n");
+        return;
+    }
+
+    try writer.print("\npub const Extended = struct {{\n", .{});
+
+    for (flags_by_bitpos) |maybe_flag_index, bitpos| {
+        const flag_index = maybe_flag_index orelse {
+            try writer.print("_reserved_bit_{}: bool = false,\n", .{bitpos});
+            continue;
+        };
+        const enumerant = enumerants[flag_index];
+
+        try renderOperand(writer, .mask, enumerant.enumerant, enumerant.parameters, extended_structs);
+    }
+
+    try writer.writeAll("};\n};\n");
+}
+
+fn renderOperand(
+    writer: anytype,
+    kind: enum {
+        @"union",
+        instruction,
+        mask,
+    },
+    field_name: []const u8,
+    parameters: []const g.Operand,
+    extended_structs: ExtendedStructSet,
+) !void {
+    if (kind == .instruction) {
+        try writer.writeByte('.');
+    }
+    try writer.print("{}", .{std.zig.fmtId(field_name)});
+    if (parameters.len == 0) {
+        switch (kind) {
+            .@"union" => try writer.writeAll(",\n"),
+            .instruction => try writer.writeAll(" => void,\n"),
+            .mask => try writer.writeAll(": bool = false,\n"),
+        }
+        return;
+    }
+
+    if (kind == .instruction) {
+        try writer.writeAll(" => ");
+    } else {
+        try writer.writeAll(": ");
+    }
+
+    if (kind == .mask) {
+        try writer.writeByte('?');
+    }
+
+    try writer.writeAll("struct{");
+
+    for (parameters) |param, j| {
+        if (j != 0) {
+            try writer.writeAll(", ");
+        }
+
+        try renderFieldName(writer, parameters, j);
+        try writer.writeAll(": ");
+
+        if (param.quantifier) |q| {
+            switch (q) {
+                .@"?" => try writer.writeByte('?'),
+                .@"*" => try writer.writeAll("[]const "),
+            }
+        }
+
+        try writer.print("{}", .{std.zig.fmtId(param.kind)});
+
+        if (extended_structs.contains(param.kind)) {
+            try writer.writeAll(".Extended");
+        }
+
+        if (param.quantifier) |q| {
+            switch (q) {
+                .@"?" => try writer.writeAll(" = null"),
+                .@"*" => try writer.writeAll(" = &.{}"),
+            }
+        }
+    }
+
+    try writer.writeAll("}");
+
+    if (kind == .mask) {
+        try writer.writeAll(" = null");
+    }
+
+    try writer.writeAll(",\n");
+}
+
+fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void {
+    const operand = operands[field_index];
+
+    // Should be enough for all names - adjust as needed.
+    var name_buffer = std.BoundedArray(u8, 64){
+        .buffer = undefined,
+    };
+
+    derive_from_kind: {
+        // Operand names are often in the json encoded as "'Name'" (with two sets of quotes).
+        // Additionally, some operands have ~ in them at the end (D~ref~).
+        const name = std.mem.trim(u8, operand.name, "'~");
+        if (name.len == 0) {
+            break :derive_from_kind;
+        }
+
+        // Some names have weird characters in them (like newlines) - skip any such ones.
+        // Use the same loop to transform to snake-case.
+        for (name) |c| {
+            switch (c) {
+                'a'...'z', '0'...'9' => try name_buffer.append(c),
+                'A'...'Z' => try name_buffer.append(std.ascii.toLower(c)),
+                ' ', '~' => try name_buffer.append('_'),
+                else => break :derive_from_kind,
+            }
+        }
+
+        // Assume there are no duplicate 'name' fields.
+        try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())});
+        return;
+    }
+
+    // Translate to snake case.
+    name_buffer.len = 0;
+    for (operand.kind) |c, i| {
+        switch (c) {
+            'a'...'z', '0'...'9' => try name_buffer.append(c),
+            'A'...'Z' => if (i > 0 and std.ascii.isLower(operand.kind[i - 1])) {
+                try name_buffer.appendSlice(&[_]u8{ '_', std.ascii.toLower(c) });
+            } else {
+                try name_buffer.append(std.ascii.toLower(c));
+            },
+            else => unreachable, // Assume that the name is valid C-syntax (and contains no underscores).
+        }
+    }
+
+    try writer.print("{}", .{std.zig.fmtId(name_buffer.slice())});
+
+    // For fields derived from type name, there could be any amount.
+    // Simply check against all other fields, and if another similar one exists, add a number.
+    const need_extra_index = for (operands) |other_operand, i| {
+        if (i != field_index and std.mem.eql(u8, operand.kind, other_operand.kind)) {
+            break true;
+        }
+    } else false;
+
+    if (need_extra_index) {
+        try writer.print("_{}", .{field_index});
+    }
 }
 
 fn parseHexInt(text: []const u8) !u31 {
@@ -142,7 +514,7 @@ fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn {
         \\
         \\Generates Zig bindings for a SPIR-V specification .json (either core or
         \\extinst versions). The result, printed to stdout, should be used to update
-        \\files in src/codegen/spirv.
+        \\files in src/codegen/spirv. Don't forget to format the output.
         \\
         \\The relevant specifications can be obtained from the SPIR-V registry:
         \\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/