Commit 8c72ad5320

Robin Voetter <robin@voetter.nl>
2023-05-29 15:25:09
spirv: cache for ints
1 parent aade6f1
Changed files (2)
src/codegen/spirv/TypeConstantCache.zig
@@ -56,6 +56,25 @@ const Tag = enum {
     type_array,
 
     // -- Values
+    /// Value of type u8
+    /// data is value
+    uint8,
+    /// Value of type u32
+    /// data is value
+    uint32,
+    // TODO: More specialized tags here.
+    /// Integer value for signed values that are smaller than 32 bits.
+    /// data is pointer to Int32
+    int_small,
+    /// Integer value for unsigned values that are smaller than 32 bits.
+    /// data is pointer to UInt32
+    uint_small,
+    /// Integer value for signed values that are beteen 32 and 64 bits.
+    /// data is pointer to Int64
+    int_large,
+    /// Integer value for unsinged values that are beteen 32 and 64 bits.
+    /// data is pointer to UInt64
+    uint_large,
     /// Value of type f16
     /// data is value
     float16,
@@ -90,6 +109,52 @@ const Tag = enum {
             return @bitCast(f64, bits);
         }
     };
+
+    const Int32 = struct {
+        ty: Ref,
+        value: i32,
+    };
+
+    const UInt32 = struct {
+        ty: Ref,
+        value: u32,
+    };
+
+    const UInt64 = struct {
+        ty: Ref,
+        low: u32,
+        high: u32,
+
+        fn encode(ty: Ref, value: u64) Int64 {
+            return .{
+                .ty = ty,
+                .low = @truncate(u32, value),
+                .high = @truncate(u32, value >> 32),
+            };
+        }
+
+        fn decode(self: UInt64) u64 {
+            return @as(u64, self.low) | (@as(u64, self.high) << 32);
+        }
+    };
+
+    const Int64 = struct {
+        ty: Ref,
+        low: u32,
+        high: u32,
+
+        fn encode(ty: Ref, value: i64) Int64 {
+            return .{
+                .ty = ty,
+                .low = @truncate(u32, @bitCast(u64, value)),
+                .high = @truncate(u32, @bitCast(u64, value) >> 32),
+            };
+        }
+
+        fn decode(self: Int64) i64 {
+            return @bitCast(i64, @as(u64, self.low) | (@as(u64, self.high) << 32));
+        }
+    };
 };
 
 pub const Ref = enum(u32) { _ };
@@ -108,6 +173,7 @@ pub const Key = union(enum) {
     array_type: ArrayType,
 
     // -- values
+    int: Int,
     float: Float,
 
     pub const IntType = std.builtin.Type.Int;
@@ -128,6 +194,41 @@ pub const Key = union(enum) {
         stride: u32 = 0,
     };
 
+    pub const Int = struct {
+        /// The type: any bitness integer.
+        ty: Ref,
+        /// The actual value. Only uint64 and int64 types
+        /// are available here: Smaller types should use these
+        /// fields.
+        value: Value,
+
+        pub const Value = union(enum) {
+            uint64: u64,
+            int64: i64,
+        };
+
+        /// Turns this value into the corresponding 32-bit literal, 2s complement signed.
+        fn toBits32(self: Int) u32 {
+            return switch (self.value) {
+                .uint64 => |val| @intCast(u32, val),
+                .int64 => |val| if (val < 0) @bitCast(u32, @intCast(i32, val)) else @intCast(u32, val),
+            };
+        }
+
+        fn toBits64(self: Int) u64 {
+            return switch (self.value) {
+                .uint64 => |val| val,
+                .int64 => |val| @bitCast(u64, val),
+            };
+        }
+
+        fn to(self: Int, comptime T: type) T {
+            return switch (self.value) {
+                inline else => |val| @intCast(T, val),
+            };
+        }
+    };
+
     /// Represents a numberic value of some type.
     pub const Float = struct {
         /// The type: 16, 32, or 64-bit float.
@@ -212,6 +313,7 @@ fn emit(
     section: *Section,
 ) !void {
     const key = self.lookup(ref);
+    const Lit = spec.LiteralContextDependentNumber;
     switch (key) {
         .void_type => {
             try section.emit(spv.gpa, .OpTypeVoid, .{ .id_result = result_id });
@@ -260,9 +362,24 @@ fn emit(
                 try spv.decorate(result_id, .{ .ArrayStride = .{ .array_stride = array.stride } });
             }
         },
+        .int => |int| {
+            const int_type = self.lookup(int.ty).int_type;
+            const ty_id = self.resultId(int.ty);
+            const lit: Lit = switch (int_type.bits) {
+                1...32 => .{ .uint32 = int.toBits32() },
+                33...64 => .{ .uint64 = int.toBits64() },
+                else => unreachable,
+            };
+
+            try section.emit(spv.gpa, .OpConstant, .{
+                .id_result_type = ty_id,
+                .id_result = result_id,
+                .value = lit,
+            });
+        },
         .float => |float| {
             const ty_id = self.resultId(float.ty);
-            const lit: spec.LiteralContextDependentNumber = switch (float.value) {
+            const lit: Lit = switch (float.value) {
                 .float16 => |value| .{ .uint32 = @bitCast(u16, value) },
                 .float32 => |value| .{ .float32 = value },
                 .float64 => |value| .{ .float64 = value },
@@ -330,6 +447,58 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
             .result_id = result_id,
             .data = try self.addExtra(spv, array),
         },
+        .int => |int| blk: {
+            const int_type = self.lookup(int.ty).int_type;
+            if (int_type.signedness == .unsigned and int_type.bits == 8) {
+                break :blk .{
+                    .tag = .uint8,
+                    .result_id = result_id,
+                    .data = int.to(u8),
+                };
+            } else if (int_type.signedness == .unsigned and int_type.bits == 32) {
+                break :blk .{
+                    .tag = .uint32,
+                    .result_id = result_id,
+                    .data = int.to(u32),
+                };
+            }
+
+            switch (int.value) {
+                inline else => |val| {
+                    if (val >= 0 and val <= std.math.maxInt(u32)) {
+                        break :blk .{
+                            .tag = .uint_small,
+                            .result_id = result_id,
+                            .data = try self.addExtra(spv, Tag.UInt32{
+                                .ty = int.ty,
+                                .value = @intCast(u32, val),
+                            }),
+                        };
+                    } else if (val >= std.math.minInt(i32) and val <= std.math.maxInt(i32)) {
+                        break :blk .{
+                            .tag = .int_small,
+                            .result_id = result_id,
+                            .data = try self.addExtra(spv, Tag.Int32{
+                                .ty = int.ty,
+                                .value = @intCast(i32, val),
+                            }),
+                        };
+                    } else if (val < 0) {
+                        break :blk .{
+                            .tag = .int_large,
+                            .result_id = result_id,
+                            .data = try self.addExtra(spv, Tag.Int64.encode(int.ty, @intCast(i64, val))),
+                        };
+                    } else {
+                        break :blk .{
+                            .tag = .uint_large,
+                            .result_id = result_id,
+                            .data = try self.addExtra(spv, Tag.UInt64.encode(int.ty, @intCast(u64, val))),
+                        };
+                    }
+                },
+            }
+        },
         .float => |float| switch (self.lookup(float.ty).float_type.bits) {
             16 => .{
                 .tag = .float16,
@@ -391,9 +560,45 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
             .value = .{ .float32 = @bitCast(f32, data) },
         } },
         .float64 => .{ .float = .{
-            .ty = self.get(.{ .float_type = .{ .bits = 32 } }),
+            .ty = self.get(.{ .float_type = .{ .bits = 64 } }),
             .value = .{ .float64 = self.extraData(Tag.Float64, data).decode() },
         } },
+        .uint8 => .{ .int = .{
+            .ty = self.get(.{ .int_type = .{ .signedness = .unsigned, .bits = 8 } }),
+            .value = .{ .uint64 = data },
+        } },
+        .uint32 => .{ .int = .{
+            .ty = self.get(.{ .int_type = .{ .signedness = .unsigned, .bits = 32 } }),
+            .value = .{ .uint64 = data },
+        } },
+        .int_small => {
+            const payload = self.extraData(Tag.Int32, data);
+            return .{ .int = .{
+                .ty = payload.ty,
+                .value = .{ .int64 = payload.value },
+            } };
+        },
+        .uint_small => {
+            const payload = self.extraData(Tag.UInt32, data);
+            return .{ .int = .{
+                .ty = payload.ty,
+                .value = .{ .uint64 = payload.value },
+            } };
+        },
+        .int_large => {
+            const payload = self.extraData(Tag.Int64, data);
+            return .{ .int = .{
+                .ty = payload.ty,
+                .value = .{ .int64 = payload.decode() },
+            } };
+        },
+        .uint_large => {
+            const payload = self.extraData(Tag.UInt64, data);
+            return .{ .int = .{
+                .ty = payload.ty,
+                .value = .{ .uint64 = payload.decode() },
+            } };
+        },
     };
 }
 
@@ -409,6 +614,7 @@ fn addExtraAssumeCapacity(self: *Self, extra: anytype) !u32 {
         const field_val = @field(extra, field.name);
         const word = switch (field.type) {
             u32 => field_val,
+            i32 => @bitCast(u32, field_val),
             Ref => @enumToInt(field_val),
             else => @compileError("Invalid type: " ++ @typeName(field.type)),
         };
@@ -428,6 +634,7 @@ fn extraDataTrail(self: Self, comptime T: type, offset: u32) struct { data: T, t
         const word = self.extra.items[offset + i];
         @field(result, field.name) = switch (field.type) {
             u32 => word,
+            i32 => @bitCast(i32, word),
             Ref => @intToEnum(Ref, word),
             else => @compileError("Invalid type: " ++ @typeName(field.type)),
         };
src/codegen/spirv.zig
@@ -1293,10 +1293,14 @@ pub const DeclGen = struct {
                 const total_len = std.math.cast(u32, ty.arrayLenIncludingSentinel()) orelse {
                     return self.fail("array type of {} elements is too large", .{ty.arrayLenIncludingSentinel()});
                 };
-                _ = total_len;
-                return self.spv.resolve(.{ .array_type = .{
+                const len_ty_ref = try self.intType2(.unsigned, 32);
+                const len_ref = try self.spv.resolve(.{ .int = .{
+                    .ty = len_ty_ref,
+                    .value = .{ .uint64 = total_len },
+                } });
+                return try self.spv.resolve(.{ .array_type = .{
                     .element_type = elem_ty_ref,
-                    .length = @intToEnum(SpvRef, 0),
+                    .length = len_ref,
                 } });
             },
             else => unreachable, // TODO