Commit e05ace7673

Robin Voetter <robin@voetter.nl>
2023-05-29 17:22:31
spirv: cache function prototypes
1 parent 8c72ad5
Changed files (2)
src/codegen/spirv/TypeConstantCache.zig
@@ -54,6 +54,9 @@ const Tag = enum {
     /// Array type
     /// data is payload to ArrayType
     type_array,
+    /// Function (proto)type.
+    /// data is payload to FunctionType
+    type_function,
 
     // -- Values
     /// Value of type u8
@@ -90,6 +93,13 @@ const Tag = enum {
     const VectorType = Key.VectorType;
     const ArrayType = Key.ArrayType;
 
+    // Trailing:
+    // - [param_len]Ref: parameter types
+    const FunctionType = struct {
+        param_len: u32,
+        return_type: Ref,
+    };
+
     const Float64 = struct {
         // Low-order 32 bits of the value.
         low: u32,
@@ -171,6 +181,7 @@ pub const Key = union(enum) {
     float_type: FloatType,
     vector_type: VectorType,
     array_type: ArrayType,
+    function_type: FunctionType,
 
     // -- values
     int: Int,
@@ -194,6 +205,11 @@ pub const Key = union(enum) {
         stride: u32 = 0,
     };
 
+    pub const FunctionType = struct {
+        return_type: Ref,
+        parameters: []const Ref,
+    };
+
     pub const Int = struct {
         /// The type: any bitness integer.
         ty: Ref,
@@ -254,13 +270,33 @@ pub const Key = union(enum) {
                     .float64 => |value| std.hash.autoHash(&hasher, @bitCast(u64, value)),
                 }
             },
+            .function_type => |func| {
+                std.hash.autoHash(&hasher, func.return_type);
+                for (func.parameters) |param_type| {
+                    std.hash.autoHash(&hasher, param_type);
+                }
+            },
             inline else => |key| std.hash.autoHash(&hasher, key),
         }
         return @truncate(u32, hasher.final());
     }
 
     fn eql(a: Key, b: Key) bool {
-        return std.meta.eql(a, b);
+        const KeyTag = @typeInfo(Key).Union.tag_type.?;
+        const a_tag: KeyTag = a;
+        const b_tag: KeyTag = b;
+        if (a_tag != b_tag) {
+            return false;
+        }
+        return switch (a) {
+            .function_type => |a_func| {
+                const b_func = a.function_type;
+                return a_func.return_type == b_func.return_type and
+                    std.mem.eql(Ref, a_func.parameters, b_func.parameters);
+            },
+            // TODO: Unroll?
+            else => std.meta.eql(a, b),
+        };
     }
 
     pub const Adapter = struct {
@@ -362,6 +398,14 @@ fn emit(
                 try spv.decorate(result_id, .{ .ArrayStride = .{ .array_stride = array.stride } });
             }
         },
+        .function_type => |function| {
+            try section.emitRaw(spv.gpa, .OpTypeFunction, 2 + function.parameters.len);
+            section.writeOperand(IdResult, result_id);
+            section.writeOperand(IdResult, self.resultId(function.return_type));
+            for (function.parameters) |param_type| {
+                section.writeOperand(IdResult, self.resultId(param_type));
+            }
+        },
         .int => |int| {
             const int_type = self.lookup(int.ty).int_type;
             const ty_id = self.resultId(int.ty);
@@ -393,18 +437,6 @@ fn emit(
     }
 }
 
-/// Get the ref for a key that has already been added to the cache.
-fn get(self: *const Self, key: Key) Ref {
-    const adapter: Key.Adapter = .{ .self = self };
-    const index = self.map.getIndexAdapted(key, adapter).?;
-    return @intToEnum(Ref, index);
-}
-
-/// Get the result-id for a key that has already been added to the cache.
-fn getId(self: *const Self, key: Key) IdResult {
-    return self.resultId(self.get(key));
-}
-
 /// Add a key to this cache. Returns a reference to the key that
 /// was added. The corresponding result-id can be queried using
 /// self.resultId with the result.
@@ -447,6 +479,18 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
             .result_id = result_id,
             .data = try self.addExtra(spv, array),
         },
+        .function_type => |function| blk: {
+            const extra = try self.addExtra(spv, Tag.FunctionType{
+                .param_len = @intCast(u32, function.parameters.len),
+                .return_type = function.return_type,
+            });
+            try self.extra.appendSlice(spv.gpa, @ptrCast([]const u32, function.parameters));
+            break :blk .{
+                .tag = .type_function,
+                .result_id = result_id,
+                .data = extra,
+            };
+        },
         .int => |int| blk: {
             const int_type = self.lookup(int.ty).int_type;
             if (int_type.signedness == .unsigned and int_type.bits == 8) {
@@ -523,13 +567,8 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
     return @intToEnum(Ref, entry.index);
 }
 
-/// Look op the result-id that corresponds to a particular
-/// ref.
-pub fn resultId(self: Self, ref: Ref) IdResult {
-    return self.items.items(.result_id)[@enumToInt(ref)];
-}
-
 /// Turn a Ref back into a Key.
+/// The Key is valid until the next call to resolve().
 pub fn lookup(self: *const Self, ref: Ref) Key {
     const item = self.items.get(@enumToInt(ref));
     const data = item.data;
@@ -551,6 +590,15 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
         } },
         .type_vector => .{ .vector_type = self.extraData(Tag.VectorType, data) },
         .type_array => .{ .array_type = self.extraData(Tag.ArrayType, data) },
+        .type_function => {
+            const payload = self.extraDataTrail(Tag.FunctionType, data);
+            return .{
+                .function_type = .{
+                    .return_type = payload.data.return_type,
+                    .parameters = @ptrCast([]const Ref, self.extra.items[payload.trail..][0..payload.data.param_len]),
+                },
+            };
+        },
         .float16 => .{ .float = .{
             .ty = self.get(.{ .float_type = .{ .bits = 16 } }),
             .value = .{ .float16 = @bitCast(f16, @intCast(u16, data)) },
@@ -602,6 +650,19 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
     };
 }
 
+/// Look op the result-id that corresponds to a particular
+/// ref.
+pub fn resultId(self: Self, ref: Ref) IdResult {
+    return self.items.items(.result_id)[@enumToInt(ref)];
+}
+
+/// Get the ref for a key that has already been added to the cache.
+fn get(self: *const Self, key: Key) Ref {
+    const adapter: Key.Adapter = .{ .self = self };
+    const index = self.map.getIndexAdapted(key, adapter).?;
+    return @intToEnum(Ref, index);
+}
+
 fn addExtra(self: *Self, spv: *Module, extra: anytype) !u32 {
     const fields = @typeInfo(@TypeOf(extra)).Struct.fields;
     try self.extra.ensureUnusedCapacity(spv.gpa, fields.len);
src/codegen/spirv.zig
@@ -1177,6 +1177,10 @@ pub const DeclGen = struct {
         return try self.intType(.unsigned, self.getTarget().ptrBitWidth());
     }
 
+    fn sizeType2(self: *DeclGen) !SpvRef {
+        return try self.intType2(.unsigned, self.getTarget().ptrBitWidth());
+    }
+
     /// Generate a union type, optionally with a known field. If the tag alignment is greater
     /// than that of the payload, a regular union (non-packed, with both tag and payload), will
     /// be generated as follows:
@@ -1303,6 +1307,31 @@ pub const DeclGen = struct {
                     .length = len_ref,
                 } });
             },
+            .Fn => switch (repr) {
+                .direct => {
+                    // TODO: Put this somewhere in Sema.zig
+                    if (ty.fnIsVarArgs())
+                        return self.fail("VarArgs functions are unsupported for SPIR-V", .{});
+
+                    const param_ty_refs = try self.gpa.alloc(SpvRef, ty.fnParamLen());
+                    defer self.gpa.free(param_ty_refs);
+                    for (param_ty_refs, 0..) |*param_type, i| {
+                        param_type.* = try self.resolveType2(ty.fnParamType(i), .direct);
+                    }
+                    const return_ty_ref = try self.resolveType2(ty.fnReturnType(), .direct);
+
+                    return try self.spv.resolve(.{ .function_type = .{
+                        .return_type = return_ty_ref,
+                        .parameters = param_ty_refs,
+                    } });
+                },
+                .indirect => {
+                    // TODO: Represent function pointers properly.
+                    // For now, just use an usize type.
+                    return try self.sizeType2();
+                },
+            },
+
             else => unreachable, // TODO
         }
     }