Commit 688d7055e3

Robin Voetter <robin@voetter.nl>
2024-11-02 16:05:06
spirv: assembler hacky constant placeholders
1 parent b530155
Changed files (3)
src/codegen/spirv/Assembler.zig
@@ -45,6 +45,9 @@ const Token = struct {
         pipe,
         /// =.
         equals,
+        /// $identifier. This is used (for now) for constant values, like integers.
+        /// These can be used in place of a normal `value`.
+        placeholder,
 
         fn name(self: Tag) []const u8 {
             return switch (self) {
@@ -56,6 +59,7 @@ const Token = struct {
                 .string => "<string literal>",
                 .pipe => "'|'",
                 .equals => "'='",
+                .placeholder => "<placeholder>",
             };
         }
     };
@@ -128,12 +132,19 @@ const AsmValue = union(enum) {
     /// This result-value represents a type registered into the module's type system.
     ty: IdRef,
 
+    /// This is a pre-supplied constant integer value.
+    constant: u32,
+
     /// Retrieve the result-id of this AsmValue. Asserts that this AsmValue
     /// is of a variant that allows the result to be obtained (not an unresolved
     /// forward declaration, not in the process of being declared, etc).
     pub fn resultId(self: AsmValue) IdRef {
         return switch (self) {
-            .just_declared, .unresolved_forward_reference => unreachable,
+            .just_declared,
+            .unresolved_forward_reference,
+            // TODO: Lower this value as constant?
+            .constant,
+            => unreachable,
             .value => |result| result,
             .ty => |result| result,
         };
@@ -383,16 +394,16 @@ fn processGenericInstruction(self: *Assembler) !?AsmValue {
             .OpExecutionMode, .OpExecutionModeId => &self.spv.sections.execution_modes,
             .OpVariable => switch (@as(spec.StorageClass, @enumFromInt(operands[2].value))) {
                 .Function => &self.func.prologue,
-                // These don't need to be marked in the dependency system.
-                // Probably we should add them anyway, then filter out PushConstant globals.
-                .PushConstant => &self.spv.sections.types_globals_constants,
-                else => section: {
+                .Input, .Output => section: {
                     maybe_spv_decl_index = try self.spv.allocDecl(.global);
                     try self.func.decl_deps.put(self.spv.gpa, maybe_spv_decl_index.?, {});
                     // TODO: In theory this can be non-empty if there is an initializer which depends on another global...
                     try self.spv.declareDeclDeps(maybe_spv_decl_index.?, &.{});
                     break :section &self.spv.sections.types_globals_constants;
                 },
+                // These don't need to be marked in the dependency system.
+                // Probably we should add them anyway, then filter out PushConstant globals.
+                else => &self.spv.sections.types_globals_constants,
             },
             // Default case - to be worked out further.
             else => &self.func.body,
@@ -665,6 +676,22 @@ fn parseRefId(self: *Assembler) !void {
 
 fn parseLiteralInteger(self: *Assembler) !void {
     const tok = self.currentToken();
+    if (self.eatToken(.placeholder)) {
+        const name = self.tokenText(tok)[1..];
+        const value = self.value_map.get(name) orelse {
+            return self.fail(tok.start, "invalid placeholder '${s}'", .{name});
+        };
+        switch (value) {
+            .constant => |literal32| {
+                try self.inst.operands.append(self.gpa, .{ .literal32 = literal32 });
+            },
+            else => {
+                return self.fail(tok.start, "value '{s}' cannot be used as placeholder", .{name});
+            },
+        }
+        return;
+    }
+
     try self.expectToken(.value);
     // According to the SPIR-V machine readable grammar, a LiteralInteger
     // may consist of one or more words. From the SPIR-V docs it seems like there
@@ -680,6 +707,22 @@ fn parseLiteralInteger(self: *Assembler) !void {
 
 fn parseLiteralExtInstInteger(self: *Assembler) !void {
     const tok = self.currentToken();
+    if (self.eatToken(.placeholder)) {
+        const name = self.tokenText(tok)[1..];
+        const value = self.value_map.get(name) orelse {
+            return self.fail(tok.start, "invalid placeholder '${s}'", .{name});
+        };
+        switch (value) {
+            .constant => |literal32| {
+                try self.inst.operands.append(self.gpa, .{ .literal32 = literal32 });
+            },
+            else => {
+                return self.fail(tok.start, "value '{s}' cannot be used as placeholder", .{name});
+            },
+        }
+        return;
+    }
+
     try self.expectToken(.value);
     const text = self.tokenText(tok);
     const value = std.fmt.parseInt(u32, text, 0) catch {
@@ -756,6 +799,22 @@ fn parseContextDependentNumber(self: *Assembler) !void {
 
 fn parseContextDependentInt(self: *Assembler, signedness: std.builtin.Signedness, width: u32) !void {
     const tok = self.currentToken();
+    if (self.eatToken(.placeholder)) {
+        const name = self.tokenText(tok)[1..];
+        const value = self.value_map.get(name) orelse {
+            return self.fail(tok.start, "invalid placeholder '${s}'", .{name});
+        };
+        switch (value) {
+            .constant => |literal32| {
+                try self.inst.operands.append(self.gpa, .{ .literal32 = literal32 });
+            },
+            else => {
+                return self.fail(tok.start, "value '{s}' cannot be used as placeholder", .{name});
+            },
+        }
+        return;
+    }
+
     try self.expectToken(.value);
 
     if (width == 0 or width > 2 * @bitSizeOf(spec.Word)) {
@@ -903,6 +962,7 @@ fn nextToken(self: *Assembler, start_offset: u32) !Token {
         string,
         string_end,
         escape,
+        placeholder,
     } = .start;
     var token_start = start_offset;
     var offset = start_offset;
@@ -930,6 +990,10 @@ fn nextToken(self: *Assembler, start_offset: u32) !Token {
                     offset += 1;
                     break;
                 },
+                '$' => {
+                    state = .placeholder;
+                    tag = .placeholder;
+                },
                 else => {
                     state = .value;
                     tag = .value;
@@ -945,11 +1009,11 @@ fn nextToken(self: *Assembler, start_offset: u32) !Token {
                 ' ', '\t', '\r', '\n', '=', '|' => break,
                 else => {},
             },
-            .result_id => switch (c) {
+            .result_id, .placeholder => switch (c) {
                 '_', 'a'...'z', 'A'...'Z', '0'...'9' => {},
                 ' ', '\t', '\r', '\n', '=', '|' => break,
                 else => {
-                    try self.addError(offset, "illegal character in result-id", .{});
+                    try self.addError(offset, "illegal character in result-id or placeholder", .{});
                     // Again, probably a forgotten delimiter here.
                     break;
                 },
src/codegen/spirv.zig
@@ -6556,13 +6556,59 @@ const NavGen = struct {
             // for the string, we still use the next u32 for the null terminator.
             extra_i += (constraint.len + name.len + (2 + 3)) / 4;
 
-            if (self.typeOf(input).zigTypeTag(zcu) == .type) {
-                // This assembly input is a type instead of a value.
-                // That's fine for now, just make sure to resolve it as such.
-                const val = (try self.air.value(input, self.pt)).?;
-                const ty_id = try self.resolveType(val.toType(), .direct);
-                try as.value_map.put(as.gpa, name, .{ .ty = ty_id });
+            const input_ty = self.typeOf(input);
+
+            if (std.mem.eql(u8, constraint, "c")) {
+                // constant
+                const val = (try self.air.value(input, self.pt)) orelse {
+                    return self.fail("assembly inputs with 'c' constraint have to be compile-time known", .{});
+                };
+
+                // TODO: This entire function should be handled a bit better...
+                const ip = &zcu.intern_pool;
+                switch (ip.indexToKey(val.toIntern())) {
+                    .int_type,
+                    .ptr_type,
+                    .array_type,
+                    .vector_type,
+                    .opt_type,
+                    .anyframe_type,
+                    .error_union_type,
+                    .simple_type,
+                    .struct_type,
+                    .union_type,
+                    .opaque_type,
+                    .enum_type,
+                    .func_type,
+                    .error_set_type,
+                    .inferred_error_set_type,
+                    => unreachable, // types, not values
+
+                    .undef => return self.fail("assembly input with 'c' constraint cannot be undefined", .{}),
+
+                    .int => {
+                        try as.value_map.put(as.gpa, name, .{ .constant = @intCast(val.toUnsignedInt(zcu)) });
+                    },
+
+                    else => unreachable, // TODO
+                }
+            } else if (std.mem.eql(u8, constraint, "t")) {
+                // type
+                if (input_ty.zigTypeTag(zcu) == .type) {
+                    // This assembly input is a type instead of a value.
+                    // That's fine for now, just make sure to resolve it as such.
+                    const val = (try self.air.value(input, self.pt)).?;
+                    const ty_id = try self.resolveType(val.toType(), .direct);
+                    try as.value_map.put(as.gpa, name, .{ .ty = ty_id });
+                } else {
+                    const ty_id = try self.resolveType(input_ty, .direct);
+                    try as.value_map.put(as.gpa, name, .{ .ty = ty_id });
+                }
             } else {
+                if (input_ty.zigTypeTag(zcu) == .type) {
+                    return self.fail("use the 't' constraint to supply types to SPIR-V inline assembly", .{});
+                }
+
                 const val_id = try self.resolve(input);
                 try as.value_map.put(as.gpa, name, .{ .value = val_id });
             }
@@ -6624,6 +6670,7 @@ const NavGen = struct {
                 .just_declared, .unresolved_forward_reference => unreachable,
                 .ty => return self.fail("cannot return spir-v type as value from assembly", .{}),
                 .value => |ref| return ref,
+                .constant => return self.fail("cannot return constant from assembly", .{}),
             }
 
             // TODO: Multiple results
src/Sema.zig
@@ -17605,6 +17605,8 @@ fn analyzePtrArithmetic(
         } else break :rs ptr_src;
     };
 
+    try sema.requireRuntimeBlock(block, op_src, runtime_src);
+
     const target = zcu.getTarget();
     if (target_util.arePointersLogical(target, ptr_info.flags.address_space)) {
         return sema.failWithOwnedErrorMsg(block, msg: {
@@ -17623,7 +17625,6 @@ fn analyzePtrArithmetic(
         });
     }
 
-    try sema.requireRuntimeBlock(block, op_src, runtime_src);
     return block.addInst(.{
         .tag = air_tag,
         .data = .{ .ty_pl = .{