Commit ef638502d4

Robin Voetter <robin@voetter.nl>
2024-04-06 02:41:56
spirv: remove cache usage from assembler
1 parent 97a6776
Changed files (1)
src
codegen
src/codegen/spirv/Assembler.zig
@@ -9,10 +9,9 @@ const Opcode = spec.Opcode;
 const Word = spec.Word;
 const IdRef = spec.IdRef;
 const IdResult = spec.IdResult;
+const StorageClass = spec.StorageClass;
 
 const SpvModule = @import("Module.zig");
-const CacheRef = SpvModule.CacheRef;
-const CacheKey = SpvModule.CacheKey;
 
 /// Represents a token in the assembly template.
 const Token = struct {
@@ -127,16 +126,16 @@ const AsmValue = union(enum) {
     value: IdRef,
 
     /// This result-value represents a type registered into the module's type system.
-    ty: CacheRef,
+    ty: IdRef,
 
     /// Retrieve the result-id of this AsmValue. Asserts that this AsmValue
     /// is of a variant that allows the result to be obtained (not an unresolved
     /// forward declaration, not in the process of being declared, etc).
-    pub fn resultId(self: AsmValue, spv: *const SpvModule) IdRef {
+    pub fn resultId(self: AsmValue) IdRef {
         return switch (self) {
             .just_declared, .unresolved_forward_reference => unreachable,
             .value => |result| result,
-            .ty => |ref| spv.resultId(ref),
+            .ty => |result| result,
         };
     }
 };
@@ -292,23 +291,23 @@ fn processInstruction(self: *Assembler) !void {
 /// refers to the result.
 fn processTypeInstruction(self: *Assembler) !AsmValue {
     const operands = self.inst.operands.items;
-    const ref = switch (self.inst.opcode) {
-        .OpTypeVoid => try self.spv.resolve(.void_type),
-        .OpTypeBool => try self.spv.resolve(.bool_type),
+    const section = &self.spv.sections.types_globals_constants;
+    const id = switch (self.inst.opcode) {
+        .OpTypeVoid => try self.spv.voidType(),
+        .OpTypeBool => try self.spv.boolType(),
         .OpTypeInt => blk: {
-            // const signedness: std.builtin.Signedness = switch (operands[2].literal32) {
-            //     0 => .unsigned,
-            //     1 => .signed,
-            //     else => {
-            //         // TODO: Improve source location.
-            //         return self.fail(0, "{} is not a valid signedness (expected 0 or 1)", .{operands[2].literal32});
-            //     },
-            // };
-            // const width = std.math.cast(u16, operands[1].literal32) orelse {
-            //     return self.fail(0, "int type of {} bits is too large", .{operands[1].literal32});
-            // };
-            // break :blk try self.spv.intType(signedness, width);
-            break :blk @as(CacheRef, @enumFromInt(0)); // TODO(robin): fix
+            const signedness: std.builtin.Signedness = switch (operands[2].literal32) {
+                0 => .unsigned,
+                1 => .signed,
+                else => {
+                    // TODO: Improve source location.
+                    return self.fail(0, "{} is not a valid signedness (expected 0 or 1)", .{operands[2].literal32});
+                },
+            };
+            const width = std.math.cast(u16, operands[1].literal32) orelse {
+                return self.fail(0, "int type of {} bits is too large", .{operands[1].literal32});
+            };
+            break :blk try self.spv.intType(signedness, width);
         },
         .OpTypeFloat => blk: {
             const bits = operands[1].literal32;
@@ -318,43 +317,49 @@ fn processTypeInstruction(self: *Assembler) !AsmValue {
                     return self.fail(0, "{} is not a valid bit count for floats (expected 16, 32 or 64)", .{bits});
                 },
             }
-            break :blk try self.spv.resolve(.{ .float_type = .{ .bits = @intCast(bits) } });
+            break :blk try self.spv.floatType(@intCast(bits));
+        },
+        .OpTypeVector => blk: {
+            const child_type = try self.resolveRefId(operands[1].ref_id);
+            break :blk try self.spv.vectorType(operands[2].literal32, child_type);
         },
-        .OpTypeVector => try self.spv.resolve(.{ .vector_type = .{
-            .component_type = try self.resolveTypeRef(operands[1].ref_id),
-            .component_count = operands[2].literal32,
-        } }),
         .OpTypeArray => {
             // TODO: The length of an OpTypeArray is determined by a constant (which may be a spec constant),
             // and so some consideration must be taken when entering this in the type system.
             return self.todo("process OpTypeArray", .{});
         },
         .OpTypePointer => blk: {
-            break :blk try self.spv.resolve(.{
-                .ptr_type = .{
-                    .storage_class = @enumFromInt(operands[1].value),
-                    .child_type = try self.resolveTypeRef(operands[2].ref_id),
-                    // TODO: This should be a proper reference resolved via OpTypeForwardPointer
-                    .fwd = @enumFromInt(std.math.maxInt(u32)),
-                },
+            const storage_class: StorageClass = @enumFromInt(operands[1].value);
+            const child_type = try self.resolveRefId(operands[2].ref_id);
+            const result_id = self.spv.allocId();
+            try section.emit(self.spv.gpa, .OpTypePointer, .{
+                .id_result = result_id,
+                .storage_class = storage_class,
+                .type = child_type,
             });
+            break :blk result_id;
         },
         .OpTypeFunction => blk: {
             const param_operands = operands[2..];
-            const param_types = try self.spv.gpa.alloc(CacheRef, param_operands.len);
+            const return_type = try self.resolveRefId(operands[1].ref_id);
+
+            const param_types = try self.spv.gpa.alloc(IdRef, param_operands.len);
             defer self.spv.gpa.free(param_types);
-            for (param_types, 0..) |*param, i| {
-                param.* = try self.resolveTypeRef(param_operands[i].ref_id);
+            for (param_types, param_operands) |*param, operand| {
+                param.* = try self.resolveRefId(operand.ref_id);
             }
-            break :blk try self.spv.resolve(.{ .function_type = .{
-                .return_type = try self.resolveTypeRef(operands[1].ref_id),
-                .parameters = param_types,
-            } });
+            const result_id = self.spv.allocId();
+            try section.emit(self.spv.gpa, .OpTypeFunction, .{
+                .id_result = result_id,
+                .return_type = return_type,
+                .id_ref_2 = param_types,
+            });
+            break :blk result_id;
         },
         else => return self.todo("process type instruction {s}", .{@tagName(self.inst.opcode)}),
     };
 
-    return AsmValue{ .ty = ref };
+    return AsmValue{ .ty = id };
 }
 
 /// Emit `self.inst` into `self.spv` and `self.func`, and return the AsmValue
@@ -411,7 +416,7 @@ fn processGenericInstruction(self: *Assembler) !?AsmValue {
             .ref_id => |index| {
                 const result = try self.resolveRef(index);
                 try section.ensureUnusedCapacity(self.spv.gpa, 1);
-                section.writeOperand(spec.IdRef, result.resultId(self.spv));
+                section.writeOperand(spec.IdRef, result.resultId());
             },
             .string => |offset| {
                 const text = std.mem.sliceTo(self.inst.string_bytes.items[offset..], 0);
@@ -460,18 +465,9 @@ fn resolveRef(self: *Assembler, ref: AsmValue.Ref) !AsmValue {
     }
 }
 
-/// Resolve a value reference as type.
-fn resolveTypeRef(self: *Assembler, ref: AsmValue.Ref) !CacheRef {
+fn resolveRefId(self: *Assembler, ref: AsmValue.Ref) !IdRef {
     const value = try self.resolveRef(ref);
-    switch (value) {
-        .just_declared, .unresolved_forward_reference => unreachable,
-        .ty => |ty_ref| return ty_ref,
-        else => {
-            const name = self.value_map.keys()[ref];
-            // TODO: Improve source location.
-            return self.fail(0, "expected operand %{s} to refer to a type", .{name});
-        },
-    }
+    return value.resultId();
 }
 
 /// Attempt to parse an instruction into `self.inst`.
@@ -710,22 +706,41 @@ fn parseContextDependentNumber(self: *Assembler) !void {
     assert(self.inst.opcode == .OpConstant or self.inst.opcode == .OpSpecConstant);
 
     const tok = self.currentToken();
-    const result_type_ref = try self.resolveTypeRef(self.inst.operands.items[0].ref_id);
-    const result_type = self.spv.cache.lookup(result_type_ref);
-    switch (result_type) {
-        .int_type => |int| {
-            try self.parseContextDependentInt(int.signedness, int.bits);
-        },
-        .float_type => |float| {
-            switch (float.bits) {
+    const result = try self.resolveRef(self.inst.operands.items[0].ref_id);
+    const result_id = result.resultId();
+    // We are going to cheat a little bit: The types we are interested in, int and float,
+    // are added to the module and cached via self.spv.intType and self.spv.floatType. Therefore,
+    // we can determine the width of these types by directly checking the cache.
+    // This only works if the Assembler and codegen both use spv.intType and spv.floatType though.
+    // We don't expect there to be many of these types, so just look it up every time.
+    // TODO: Count be improved to be a little bit more efficent.
+
+    {
+        var it = self.spv.cache2.int_types.iterator();
+        while (it.next()) |entry| {
+            const id = entry.value_ptr.*;
+            if (id != result_id) continue;
+            const info = entry.key_ptr.*;
+            return try self.parseContextDependentInt(info.signedness, info.bits);
+        }
+    }
+
+    {
+        var it = self.spv.cache2.float_types.iterator();
+        while (it.next()) |entry| {
+            const id = entry.value_ptr.*;
+            if (id != result_id) continue;
+            const info = entry.key_ptr.*;
+            switch (info.bits) {
                 16 => try self.parseContextDependentFloat(16),
                 32 => try self.parseContextDependentFloat(32),
                 64 => try self.parseContextDependentFloat(64),
-                else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{float.bits}),
+                else => return self.fail(tok.start, "cannot parse {}-bit info literal", .{info.bits}),
             }
-        },
-        else => return self.fail(tok.start, "cannot parse literal constant", .{}),
+        }
     }
+
+    return self.fail(tok.start, "cannot parse literal constant", .{});
 }
 
 fn parseContextDependentInt(self: *Assembler, signedness: std.builtin.Signedness, width: u32) !void {