Commit aade6f1195

Robin Voetter <robin@voetter.nl>
2023-05-29 14:10:02
spirv: cache for floats
1 parent b2a984c
Changed files (1)
src
codegen
src/codegen/spirv/TypeConstantCache.zig
@@ -56,11 +56,40 @@ const Tag = enum {
     type_array,
 
     // -- Values
+    /// Value of type f16
+    /// data is value
+    float16,
+    /// Value of type f32
+    /// data is value
+    float32,
+    /// Value of type f64
+    /// data is payload to Float16
+    float64,
 
     const SimpleType = enum { void, bool };
 
     const VectorType = Key.VectorType;
     const ArrayType = Key.ArrayType;
+
+    const Float64 = struct {
+        // Low-order 32 bits of the value.
+        low: u32,
+        // High-order 32 bits of the value.
+        high: u32,
+
+        fn encode(value: f64) Float64 {
+            const bits = @bitCast(u64, value);
+            return .{
+                .low = @truncate(u32, bits),
+                .high = @truncate(u32, bits >> 32),
+            };
+        }
+
+        fn decode(self: Float64) f64 {
+            const bits = @as(u64, self.low) | (@as(u64, self.high) << 32);
+            return @bitCast(f64, bits);
+        }
+    };
 };
 
 pub const Ref = enum(u32) { _ };
@@ -79,6 +108,7 @@ pub const Key = union(enum) {
     array_type: ArrayType,
 
     // -- values
+    float: Float,
 
     pub const IntType = std.builtin.Type.Int;
     pub const FloatType = std.builtin.Type.Float;
@@ -98,9 +128,33 @@ pub const Key = union(enum) {
         stride: u32 = 0,
     };
 
+    /// Represents a numberic value of some type.
+    pub const Float = struct {
+        /// The type: 16, 32, or 64-bit float.
+        ty: Ref,
+        /// The actual value.
+        value: Value,
+
+        pub const Value = union(enum) {
+            float16: f16,
+            float32: f32,
+            float64: f64,
+        };
+    };
+
     fn hash(self: Key) u32 {
         var hasher = std.hash.Wyhash.init(0);
-        std.hash.autoHash(&hasher, self);
+        switch (self) {
+            .float => |float| {
+                std.hash.autoHash(&hasher, float.ty);
+                switch (float.value) {
+                    .float16 => |value| std.hash.autoHash(&hasher, @bitCast(u16, value)),
+                    .float32 => |value| std.hash.autoHash(&hasher, @bitCast(u32, value)),
+                    .float64 => |value| std.hash.autoHash(&hasher, @bitCast(u64, value)),
+                }
+            },
+            inline else => |key| std.hash.autoHash(&hasher, key),
+        }
         return @truncate(u32, hasher.final());
     }
 
@@ -141,7 +195,7 @@ pub fn deinit(self: *Self, spv: *const Module) void {
 /// This function returns a spir-v section of (only) constant and type instructions.
 /// Additionally, decorations, debug names, etc, are all directly emitted into the
 /// `spv` module. The section is allocated with `spv.gpa`.
-pub fn materialize(self: *Self, spv: *Module) !Section {
+pub fn materialize(self: *const Self, spv: *Module) !Section {
     var section = Section{};
     errdefer section.deinit(spv.gpa);
     for (self.items.items(.result_id), 0..) |result_id, index| {
@@ -151,7 +205,7 @@ pub fn materialize(self: *Self, spv: *Module) !Section {
 }
 
 fn emit(
-    self: *Self,
+    self: *const Self,
     spv: *Module,
     result_id: IdResult,
     ref: Ref,
@@ -206,9 +260,34 @@ fn emit(
                 try spv.decorate(result_id, .{ .ArrayStride = .{ .array_stride = array.stride } });
             }
         },
+        .float => |float| {
+            const ty_id = self.resultId(float.ty);
+            const lit: spec.LiteralContextDependentNumber = switch (float.value) {
+                .float16 => |value| .{ .uint32 = @bitCast(u16, value) },
+                .float32 => |value| .{ .float32 = value },
+                .float64 => |value| .{ .float64 = value },
+            };
+            try section.emit(spv.gpa, .OpConstant, .{
+                .id_result_type = ty_id,
+                .id_result = result_id,
+                .value = lit,
+            });
+        },
     }
 }
 
+/// Get the ref for a key that has already been added to the cache.
+fn get(self: *const Self, key: Key) Ref {
+    const adapter: Key.Adapter = .{ .self = self };
+    const index = self.map.getIndexAdapted(key, adapter).?;
+    return @intToEnum(Ref, index);
+}
+
+/// Get the result-id for a key that has already been added to the cache.
+fn getId(self: *const Self, key: Key) IdResult {
+    return self.resultId(self.get(key));
+}
+
 /// Add a key to this cache. Returns a reference to the key that
 /// was added. The corresponding result-id can be queried using
 /// self.resultId with the result.
@@ -251,6 +330,24 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
             .result_id = result_id,
             .data = try self.addExtra(spv, array),
         },
+        .float => |float| switch (self.lookup(float.ty).float_type.bits) {
+            16 => .{
+                .tag = .float16,
+                .result_id = result_id,
+                .data = @bitCast(u16, float.value.float16),
+            },
+            32 => .{
+                .tag = .float32,
+                .result_id = result_id,
+                .data = @bitCast(u32, float.value.float32),
+            },
+            64 => .{
+                .tag = .float64,
+                .result_id = result_id,
+                .data = try self.addExtra(spv, Tag.Float64.encode(float.value.float64)),
+            },
+            else => unreachable,
+        },
     };
     try self.items.append(spv.gpa, item);
 
@@ -285,6 +382,18 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
         } },
         .type_vector => .{ .vector_type = self.extraData(Tag.VectorType, data) },
         .type_array => .{ .array_type = self.extraData(Tag.ArrayType, data) },
+        .float16 => .{ .float = .{
+            .ty = self.get(.{ .float_type = .{ .bits = 16 } }),
+            .value = .{ .float16 = @bitCast(f16, @intCast(u16, data)) },
+        } },
+        .float32 => .{ .float = .{
+            .ty = self.get(.{ .float_type = .{ .bits = 32 } }),
+            .value = .{ .float32 = @bitCast(f32, data) },
+        } },
+        .float64 => .{ .float = .{
+            .ty = self.get(.{ .float_type = .{ .bits = 32 } }),
+            .value = .{ .float64 = self.extraData(Tag.Float64, data).decode() },
+        } },
     };
 }