Commit 9785014938

Ali Chraghi <alichraghi@proton.me>
2024-02-19 00:48:28
spirv: OpExtInstImport in assembler
1 parent 0f75143
Changed files (4)
src/codegen/spirv/Assembler.zig
@@ -256,10 +256,15 @@ fn todo(self: *Assembler, comptime fmt: []const u8, args: anytype) Error {
 /// If this function returns `error.AssembleFail`, an explanatory
 /// error message has already been emitted into `self.errors`.
 fn processInstruction(self: *Assembler) !void {
-    const result = switch (self.inst.opcode) {
+    const result: AsmValue = switch (self.inst.opcode) {
         .OpEntryPoint => {
             return self.fail(0, "cannot export entry points via OpEntryPoint, export the kernel using callconv(.Kernel)", .{});
         },
+        .OpExtInstImport => blk: {
+            const set_name_offset = self.inst.operands.items[1].string;
+            const set_name = std.mem.sliceTo(self.inst.string_bytes.items[set_name_offset..], 0);
+            break :blk .{ .value = try self.spv.importInstructionSet(set_name) };
+        },
         else => switch (self.inst.opcode.class()) {
             .TypeDeclaration => try self.processTypeInstruction(),
             else => if (try self.processGenericInstruction()) |result|
@@ -309,7 +314,7 @@ fn processTypeInstruction(self: *Assembler) !AsmValue {
                     return self.fail(0, "{} is not a valid bit count for floats (expected 16, 32 or 64)", .{bits});
                 },
             }
-            break :blk try self.spv.resolve(.{ .float_type = .{ .bits = @as(u16, @intCast(bits)) } });
+            break :blk try self.spv.resolve(.{ .float_type = .{ .bits = @intCast(bits) } });
         },
         .OpTypeVector => try self.spv.resolve(.{ .vector_type = .{
             .component_type = try self.resolveTypeRef(operands[1].ref_id),
@@ -364,6 +369,7 @@ fn processGenericInstruction(self: *Assembler) !?AsmValue {
             .OpExecutionMode, .OpExecutionModeId => &self.spv.sections.execution_modes,
             .OpVariable => switch (@as(spec.StorageClass, @enumFromInt(operands[2].value))) {
                 .Function => &self.func.prologue,
+                .UniformConstant => &self.spv.sections.types_globals_constants,
                 else => {
                     // This is currently disabled because global variables are required to be
                     // emitted in the proper order, and this should be honored in inline assembly
@@ -473,14 +479,14 @@ fn parseInstruction(self: *Assembler) !void {
     self.inst.string_bytes.shrinkRetainingCapacity(0);
 
     const lhs_result_tok = self.currentToken();
-    const maybe_lhs_result = if (self.eatToken(.result_id_assign)) blk: {
+    const maybe_lhs_result: ?AsmValue.Ref = if (self.eatToken(.result_id_assign)) blk: {
         const name = self.tokenText(lhs_result_tok)[1..];
         const entry = try self.value_map.getOrPut(self.gpa, name);
         try self.expectToken(.equals);
         if (!entry.found_existing) {
             entry.value_ptr.* = .just_declared;
         }
-        break :blk @as(AsmValue.Ref, @intCast(entry.index));
+        break :blk @intCast(entry.index);
     } else null;
 
     const opcode_tok = self.currentToken();
@@ -550,6 +556,7 @@ fn parseOperand(self: *Assembler, kind: spec.OperandKind) Error!void {
             .LiteralInteger => try self.parseLiteralInteger(),
             .LiteralString => try self.parseString(),
             .LiteralContextDependentNumber => try self.parseContextDependentNumber(),
+            .LiteralExtInstInteger => try self.parseLiteralExtInstInteger(),
             .PairIdRefIdRef => try self.parsePhiSource(),
             else => return self.todo("parse operand of type {s}", .{@tagName(kind)}),
         },
@@ -641,7 +648,7 @@ fn parseRefId(self: *Assembler) !void {
         entry.value_ptr.* = .unresolved_forward_reference;
     }
 
-    const index = @as(AsmValue.Ref, @intCast(entry.index));
+    const index: AsmValue.Ref = @intCast(entry.index);
     try self.inst.operands.append(self.gpa, .{ .ref_id = index });
 }
 
@@ -660,6 +667,16 @@ fn parseLiteralInteger(self: *Assembler) !void {
     try self.inst.operands.append(self.gpa, .{ .literal32 = value });
 }
 
+fn parseLiteralExtInstInteger(self: *Assembler) !void {
+    const tok = self.currentToken();
+    try self.expectToken(.value);
+    const text = self.tokenText(tok);
+    const value = std.fmt.parseInt(u32, text, 0) catch {
+        return self.fail(tok.start, "'{s}' is not a valid 32-bit integer literal", .{text});
+    };
+    try self.inst.operands.append(self.gpa, .{ .literal32 = value });
+}
+
 fn parseString(self: *Assembler) !void {
     const tok = self.currentToken();
     try self.expectToken(.string);
@@ -673,7 +690,7 @@ fn parseString(self: *Assembler) !void {
     else
         text[1..];
 
-    const string_offset = @as(u32, @intCast(self.inst.string_bytes.items.len));
+    const string_offset: u32 = @intCast(self.inst.string_bytes.items.len);
     try self.inst.string_bytes.ensureUnusedCapacity(self.gpa, literal.len + 1);
     self.inst.string_bytes.appendSliceAssumeCapacity(literal);
     self.inst.string_bytes.appendAssumeCapacity(0);
@@ -730,9 +747,9 @@ fn parseContextDependentInt(self: *Assembler, signedness: std.builtin.Signedness
 
         // Note, we store the sign-extended version here.
         if (width <= @bitSizeOf(spec.Word)) {
-            try self.inst.operands.append(self.gpa, .{ .literal32 = @as(u32, @truncate(@as(u128, @bitCast(int)))) });
+            try self.inst.operands.append(self.gpa, .{ .literal32 = @truncate(@as(u128, @bitCast(int))) });
         } else {
-            try self.inst.operands.append(self.gpa, .{ .literal64 = @as(u64, @truncate(@as(u128, @bitCast(int)))) });
+            try self.inst.operands.append(self.gpa, .{ .literal64 = @truncate(@as(u128, @bitCast(int))) });
         }
         return;
     }
@@ -753,7 +770,7 @@ fn parseContextDependentFloat(self: *Assembler, comptime width: u16) !void {
         return self.fail(tok.start, "'{s}' is not a valid {}-bit float literal", .{ text, width });
     };
 
-    const float_bits = @as(Int, @bitCast(value));
+    const float_bits: Int = @bitCast(value);
     if (width <= @bitSizeOf(spec.Word)) {
         try self.inst.operands.append(self.gpa, .{ .literal32 = float_bits });
     } else {
src/codegen/spirv/Module.zig
@@ -500,9 +500,9 @@ pub fn declPtr(self: *Module, index: Decl.Index) *Decl {
 
 /// Declare ALL dependencies for a decl.
 pub fn declareDeclDeps(self: *Module, decl_index: Decl.Index, deps: []const Decl.Index) !void {
-    const begin_dep = @as(u32, @intCast(self.decl_deps.items.len));
+    const begin_dep: u32 = @intCast(self.decl_deps.items.len);
     try self.decl_deps.appendSlice(self.gpa, deps);
-    const end_dep = @as(u32, @intCast(self.decl_deps.items.len));
+    const end_dep: u32 = @intCast(self.decl_deps.items.len);
 
     const decl = self.declPtr(decl_index);
     decl.begin_dep = begin_dep;
src/codegen/spirv/Section.zig
@@ -115,8 +115,8 @@ pub fn writeWords(section: *Section, words: []const Word) void {
 
 pub fn writeDoubleWord(section: *Section, dword: DoubleWord) void {
     section.writeWords(&.{
-        @as(Word, @truncate(dword)),
-        @as(Word, @truncate(dword >> @bitSizeOf(Word))),
+        @truncate(dword),
+        @truncate(dword >> @bitSizeOf(Word)),
     });
 }
 
@@ -196,12 +196,12 @@ fn writeString(section: *Section, str: []const u8) void {
 
 fn writeContextDependentNumber(section: *Section, operand: spec.LiteralContextDependentNumber) void {
     switch (operand) {
-        .int32 => |int| section.writeWord(@as(Word, @bitCast(int))),
-        .uint32 => |int| section.writeWord(@as(Word, @bitCast(int))),
-        .int64 => |int| section.writeDoubleWord(@as(DoubleWord, @bitCast(int))),
-        .uint64 => |int| section.writeDoubleWord(@as(DoubleWord, @bitCast(int))),
-        .float32 => |float| section.writeWord(@as(Word, @bitCast(float))),
-        .float64 => |float| section.writeDoubleWord(@as(DoubleWord, @bitCast(float))),
+        .int32 => |int| section.writeWord(@bitCast(int)),
+        .uint32 => |int| section.writeWord(@bitCast(int)),
+        .int64 => |int| section.writeDoubleWord(@bitCast(int)),
+        .uint64 => |int| section.writeDoubleWord(@bitCast(int)),
+        .float32 => |float| section.writeWord(@bitCast(float)),
+        .float64 => |float| section.writeDoubleWord(@bitCast(float)),
     }
 }
 
@@ -274,8 +274,8 @@ fn operandSize(comptime Operand: type, operand: Operand) usize {
         spec.LiteralString => std.math.divCeil(usize, operand.len + 1, @sizeOf(Word)) catch unreachable, // Add one for zero-terminator
 
         spec.LiteralContextDependentNumber => switch (operand) {
-            .int32, .uint32, .float32 => @as(usize, 1),
-            .int64, .uint64, .float64 => @as(usize, 2),
+            .int32, .uint32, .float32 => 1,
+            .int64, .uint64, .float64 => 2,
         },
 
         // TODO: Where this type is used (OpSpecConstantOp) is currently not correct in the spec
src/codegen/spirv.zig
@@ -1016,7 +1016,7 @@ const DeclGen = struct {
                     const elem_ty = Type.fromInterned(array_type.child);
                     const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
 
-                    const constituents = try self.gpa.alloc(IdRef, @as(u32, @intCast(ty.arrayLenIncludingSentinel(mod))));
+                    const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod)));
                     defer self.gpa.free(constituents);
 
                     switch (aggregate.storage) {
@@ -1736,7 +1736,6 @@ const DeclGen = struct {
             .EnumLiteral,
             .ComptimeFloat,
             .ComptimeInt,
-            .Type,
             => unreachable, // Must be comptime.
 
             else => |tag| return self.todo("Implement zig type '{}'", .{tag}),
@@ -2323,18 +2322,10 @@ const DeclGen = struct {
 
             .div_float,
             .div_float_optimized,
-            // TODO: Check that this is the right operation.
             .div_trunc,
-            .div_trunc_optimized,
-            => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
-            // TODO: Check if this is the right operation
-            .rem,
-            .rem_optimized,
-            => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
-            // TODO: Check if this is the right operation
-            .mod,
-            .mod_optimized,
-            => try self.airArithOp(inst, .OpFMod, .OpSMod, .OpSMod),
+            .div_trunc_optimized => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv),
+            .rem, .rem_optimized => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem),
+            .mod, .mod_optimized => try self.airArithOp(inst, .OpFMod, .OpSMod, .OpSMod),
 
 
             .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
@@ -2348,7 +2339,7 @@ const DeclGen = struct {
 
             .splat => try self.airSplat(inst),
             .reduce, .reduce_optimized => try self.airReduce(inst),
-            .shuffle => try self.airShuffle(inst),
+            .shuffle                   => try self.airShuffle(inst),
 
             .ptr_add => try self.airPtrAdd(inst),
             .ptr_sub => try self.airPtrSub(inst),
@@ -2742,8 +2733,8 @@ const DeclGen = struct {
             else => unreachable,
         };
         const set_id = switch (target.os.tag) {
-            .opencl => try self.spv.importInstructionSet(.opencl),
-            .vulkan => try self.spv.importInstructionSet(.glsl),
+            .opencl => try self.spv.importInstructionSet("OpenCL.std"),
+            .vulkan => try self.spv.importInstructionSet("GLSL.std.450"),
             else => unreachable,
         };
 
@@ -2796,8 +2787,8 @@ const DeclGen = struct {
                 return self.todo("binary operations for composite integers", .{});
             },
             .integer, .strange_integer => switch (info.signedness) {
-                .signed => @as(usize, 1),
-                .unsigned => @as(usize, 2),
+                .signed => 1,
+                .unsigned => 2,
             },
             .float => 0,
             .bool => unreachable,
@@ -5357,7 +5348,7 @@ const DeclGen = struct {
                 const backing_bits = self.backingIntBits(bits) orelse {
                     return self.todo("implement composite int switch", .{});
                 };
-                break :blk if (backing_bits <= 32) @as(u32, 1) else 2;
+                break :blk if (backing_bits <= 32) 1 else 2;
             },
             .Enum => blk: {
                 const int_ty = cond_ty.intTagType(mod);
@@ -5365,7 +5356,7 @@ const DeclGen = struct {
                 const backing_bits = self.backingIntBits(int_info.bits) orelse {
                     return self.todo("implement composite int switch", .{});
                 };
-                break :blk if (backing_bits <= 32) @as(u32, 1) else 2;
+                break :blk if (backing_bits <= 32) 1 else 2;
             },
             .Pointer => blk: {
                 cond_indirect = try self.intFromPtr(cond_indirect);
@@ -5419,7 +5410,7 @@ const DeclGen = struct {
             for (0..num_cases) |case_i| {
                 // SPIR-V needs a literal here, which' width depends on the case condition.
                 const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
-                const items = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[case.end..][0..case.data.items_len]));
+                const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
                 const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
                 extra_index = case.end + case.data.items_len + case_body.len;
 
@@ -5428,7 +5419,7 @@ const DeclGen = struct {
                 for (items) |item| {
                     const value = (try self.air.value(item, mod)) orelse unreachable;
                     const int_val: u64 = switch (cond_ty.zigTypeTag(mod)) {
-                        .Bool, .Int => if (cond_ty.isSignedInt(mod)) @as(u64, @bitCast(value.toSignedInt(mod))) else value.toUnsignedInt(mod),
+                        .Bool, .Int => if (cond_ty.isSignedInt(mod)) @bitCast(value.toSignedInt(mod)) else value.toUnsignedInt(mod),
                         .Enum => blk: {
                             // TODO: figure out of cond_ty is correct (something with enum literals)
                             break :blk (try value.intFromEnum(cond_ty, mod)).toUnsignedInt(mod); // TODO: composite integer constants
@@ -5550,14 +5541,14 @@ const DeclGen = struct {
         const extra = self.air.extraData(Air.Asm, ty_pl.payload);
 
         const is_volatile = @as(u1, @truncate(extra.data.flags >> 31)) != 0;
-        const clobbers_len = @as(u31, @truncate(extra.data.flags));
+        const clobbers_len: u31 = @truncate(extra.data.flags);
 
         if (!is_volatile and self.liveness.isUnused(inst)) return null;
 
         var extra_i: usize = extra.end;
-        const outputs = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra_i..][0..extra.data.outputs_len]));
+        const outputs: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_i..][0..extra.data.outputs_len]);
         extra_i += outputs.len;
-        const inputs = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra_i..][0..extra.data.inputs_len]));
+        const inputs: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_i..][0..extra.data.inputs_len]);
         extra_i += inputs.len;
 
         if (outputs.len > 1) {
@@ -5679,7 +5670,7 @@ const DeclGen = struct {
         const mod = self.module;
         const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
         const extra = self.air.extraData(Air.Call, pl_op.payload);
-        const args = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[extra.end..][0..extra.data.args_len]));
+        const args: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra.end..][0..extra.data.args_len]);
         const callee_ty = self.typeOf(pl_op.operand);
         const zig_fn_ty = switch (callee_ty.zigTypeTag(mod)) {
             .Fn => callee_ty,