Commit 42c7e752e1

Robin Voetter <robin@voetter.nl>
2024-03-31 19:05:54
spirv: id range helper
This allows us to more sanely allocate a continuous range of result-ids, and avoids a bunch of nasty casting code in a few places. Its currently not used very often, but will be useful in the future.
1 parent 3942083
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -197,14 +197,26 @@ pub fn deinit(self: *Module) void {
     self.* = undefined;
 }
 
-pub fn allocId(self: *Module) spec.IdResult {
-    defer self.next_result_id += 1;
-    return @enumFromInt(self.next_result_id);
-}
+pub const IdRange = struct {
+    base: u32,
+    len: u32,
+
+    pub fn at(range: IdRange, i: usize) IdResult {
+        assert(i < range.len);
+        return @enumFromInt(range.base + i);
+    }
+};
 
-pub fn allocIds(self: *Module, n: u32) spec.IdResult {
+pub fn allocIds(self: *Module, n: u32) IdRange {
     defer self.next_result_id += n;
-    return @enumFromInt(self.next_result_id);
+    return .{
+        .base = self.next_result_id,
+        .len = n,
+    };
+}
+
+pub fn allocId(self: *Module) IdResult {
+    return self.allocIds(1).at(0);
 }
 
 pub fn idBound(self: Module) Word {
src/codegen/spirv.zig
@@ -5440,7 +5440,7 @@ const DeclGen = struct {
         };
 
         // First, pre-allocate the labels for the cases.
-        const first_case_label = self.spv.allocIds(num_cases);
+        const case_labels = self.spv.allocIds(num_cases);
         // We always need the default case - if zig has none, we will generate unreachable there.
         const default = self.spv.allocId();
 
@@ -5471,7 +5471,7 @@ const DeclGen = struct {
                 const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
                 extra_index = case.end + case.data.items_len + case_body.len;
 
-                const label: IdRef = @enumFromInt(@intFromEnum(first_case_label) + case_i);
+                const label = case_labels.at(case_i);
 
                 for (items) |item| {
                     const value = (try self.air.value(item, mod)) orelse unreachable;
@@ -5511,7 +5511,7 @@ const DeclGen = struct {
             const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
             extra_index = case.end + case.data.items_len + case_body.len;
 
-            const label: IdResult = @enumFromInt(@intFromEnum(first_case_label) + case_i);
+            const label = case_labels.at(case_i);
 
             try self.beginSpvBlock(label);