Commit 3eafe3033e

Robin Voetter <robin@voetter.nl>
2022-11-26 12:23:07
spirv: improve storage efficiency for integer and float types
In practice there are only a few variations of these types allowed, so it kind-of makes sense to write them all out. Because the types are hashed this does not actually save all that many bytes in the long run, though. Perhaps some of these types should be pre-registered?
1 parent 5826a8a
Changed files (4)
src/codegen/spirv/Assembler.zig
@@ -266,27 +266,28 @@ fn processTypeInstruction(self: *Assembler) !AsmValue {
         .OpTypeVoid => SpvType.initTag(.void),
         .OpTypeBool => SpvType.initTag(.bool),
         .OpTypeInt => blk: {
-            const payload = try self.spv.arena.create(SpvType.Payload.Int);
             const signedness: std.builtin.Signedness = switch (operands[2].literal32) {
                 0 => .unsigned,
                 1 => .signed,
                 else => {
                     // TODO: Improve source location.
-                    return self.fail(0, "'{}' is not a valid signedness (expected 0 or 1)", .{operands[2].literal32});
+                    return self.fail(0, "{} is not a valid signedness (expected 0 or 1)", .{operands[2].literal32});
                 },
             };
-            payload.* = .{
-                .width = operands[1].literal32,
-                .signedness = signedness,
+            const width = std.math.cast(u16, operands[1].literal32) orelse {
+                return self.fail(0, "int type of {} bits is too large", .{operands[1].literal32});
             };
-            break :blk SpvType.initPayload(&payload.base);
+            break :blk try SpvType.int(self.spv.arena, signedness, width);
         },
         .OpTypeFloat => blk: {
-            const payload = try self.spv.arena.create(SpvType.Payload.Float);
-            payload.* = .{
-                .width = operands[1].literal32,
-            };
-            break :blk SpvType.initPayload(&payload.base);
+            const bits = operands[1].literal32;
+            switch (bits) {
+                16, 32, 64 => {},
+                else => {
+                    return self.fail(0, "{} is not a valid bit count for floats (expected 16, 32 or 64)", .{bits});
+                },
+            }
+            break :blk SpvType.float(@intCast(u16, bits));
         },
         .OpTypeVector => blk: {
             const payload = try self.spv.arena.create(SpvType.Payload.Vector);
@@ -754,21 +755,18 @@ fn parseContextDependentNumber(self: *Assembler) !void {
     const tok = self.currentToken();
     const result_type_ref = try self.resolveTypeRef(self.inst.operands.items[0].ref_id);
     const result_type = self.spv.type_cache.keys()[@enumToInt(result_type_ref)];
-    switch (result_type.tag()) {
-        .int => {
-            const int = result_type.castTag(.int).?;
-            try self.parseContextDependentInt(int.signedness, int.width);
-        },
-        .float => {
-            const width = result_type.castTag(.float).?.width;
-            switch (width) {
-                16 => try self.parseContextDependentFloat(16),
-                32 => try self.parseContextDependentFloat(32),
-                64 => try self.parseContextDependentFloat(64),
-                else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}),
-            }
-        },
-        else => return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())}),
+    if (result_type.isInt()) {
+        try self.parseContextDependentInt(result_type.intSignedness(), result_type.intFloatBits());
+    } else if (result_type.isFloat()) {
+        const width = result_type.intFloatBits();
+        switch (width) {
+            16 => try self.parseContextDependentFloat(16),
+            32 => try self.parseContextDependentFloat(32),
+            64 => try self.parseContextDependentFloat(64),
+            else => return self.fail(tok.start, "cannot parse {}-bit float literal", .{width}),
+        }
+    } else {
+        return self.fail(tok.start, "cannot parse literal constant {s}", .{@tagName(result_type.tag())});
     }
 }
 
src/codegen/spirv/Module.zig
@@ -250,21 +250,30 @@ pub fn emitType(self: *Module, ty: Type) !IdResultType {
     switch (ty.tag()) {
         .void => try types.emit(self.gpa, .OpTypeVoid, result_id_operand),
         .bool => try types.emit(self.gpa, .OpTypeBool, result_id_operand),
-        .int => {
-            const signedness: spec.LiteralInteger = switch (ty.payload(.int).signedness) {
+        .u8,
+        .u16,
+        .u32,
+        .u64,
+        .i8,
+        .i16,
+        .i32,
+        .i64,
+        .int,
+        => {
+            const signedness: spec.LiteralInteger = switch (ty.intSignedness()) {
                 .unsigned => 0,
                 .signed => 1,
             };
 
             try types.emit(self.gpa, .OpTypeInt, .{
                 .id_result = result_id,
-                .width = ty.payload(.int).width,
+                .width = ty.intFloatBits(),
                 .signedness = signedness,
             });
         },
-        .float => try types.emit(self.gpa, .OpTypeFloat, .{
+        .f16, .f32, .f64 => try types.emit(self.gpa, .OpTypeFloat, .{
             .id_result = result_id,
-            .width = ty.payload(.float).width,
+            .width = ty.intFloatBits(),
         }),
         .vector => try types.emit(self.gpa, .OpTypeVector, .{
             .id_result = result_id,
src/codegen/spirv/type.zig
@@ -3,6 +3,8 @@
 
 const std = @import("std");
 const assert = std.debug.assert;
+const Signedness = std.builtin.Signedness;
+const Allocator = std.mem.Allocator;
 
 const spec = @import("spec.zig");
 
@@ -23,6 +25,41 @@ pub const Type = extern union {
         return .{ .ptr_otherwise = pl };
     }
 
+    pub fn int(arena: Allocator, signedness: Signedness, bits: u16) !Type {
+        const bits_and_signedness = switch (signedness) {
+            .signed => -@as(i32, bits),
+            .unsigned => @as(i32, bits),
+        };
+
+        return switch (bits_and_signedness) {
+            8 => initTag(.u8),
+            16 => initTag(.u16),
+            32 => initTag(.u32),
+            64 => initTag(.u64),
+            -8 => initTag(.i8),
+            -16 => initTag(.i16),
+            -32 => initTag(.i32),
+            -64 => initTag(.i64),
+            else => {
+                const int_payload = try arena.create(Payload.Int);
+                int_payload.* = .{
+                    .width = bits,
+                    .signedness = signedness,
+                };
+                return initPayload(&int_payload.base);
+            },
+        };
+    }
+
+    pub fn float(bits: u16) Type {
+        return switch (bits) {
+            16 => initTag(.f16),
+            32 => initTag(.f32),
+            64 => initTag(.f64),
+            else => unreachable, // Enable more types if required.
+        };
+    }
+
     pub fn tag(self: Type) Tag {
         if (@enumToInt(self.tag_if_small_enough) < Tag.no_payload_count) {
             return self.tag_if_small_enough;
@@ -80,9 +117,19 @@ pub const Type = extern union {
             .queue,
             .pipe_storage,
             .named_barrier,
+            .u8,
+            .u16,
+            .u32,
+            .u64,
+            .i8,
+            .i16,
+            .i32,
+            .i64,
+            .f16,
+            .f32,
+            .f64,
             => return true,
             .int,
-            .float,
             .vector,
             .matrix,
             .sampled_image,
@@ -132,6 +179,17 @@ pub const Type = extern union {
                     .queue,
                     .pipe_storage,
                     .named_barrier,
+                    .u8,
+                    .u16,
+                    .u32,
+                    .u64,
+                    .i8,
+                    .i16,
+                    .i32,
+                    .i64,
+                    .f16,
+                    .f32,
+                    .f64,
                     => {},
                     else => self.hashPayload(@field(Tag, field.name), &hasher),
                 }
@@ -185,6 +243,53 @@ pub const Type = extern union {
         };
     }
 
+    pub fn isInt(self: Type) bool {
+        return switch (self.tag()) {
+            .u8,
+            .u16,
+            .u32,
+            .u64,
+            .i8,
+            .i16,
+            .i32,
+            .i64,
+            .int,
+            => true,
+            else => false,
+        };
+    }
+
+    pub fn isFloat(self: Type) bool {
+        return switch (self.tag()) {
+            .f16, .f32, .f64 => true,
+            else => false,
+        };
+    }
+
+    /// Returns the number of bits that make up an int or float type.
+    /// Asserts type is either int or float.
+    pub fn intFloatBits(self: Type) u16 {
+        return switch (self.tag()) {
+            .u8, .i8 => 8,
+            .u16, .i16, .f16 => 16,
+            .u32, .i32, .f32 => 32,
+            .u64, .i64, .f64 => 64,
+            .int => self.payload(.int).width,
+            else => unreachable,
+        };
+    }
+
+    /// Returns the signedness of an integer type.
+    /// Asserts that the type is an int.
+    pub fn intSignedness(self: Type) Signedness {
+        return switch (self.tag()) {
+            .u8, .u16, .u32, .u64 => .unsigned,
+            .i8, .i16, .i32, .i64 => .signed,
+            .int => self.payload(.int).signedness,
+            else => unreachable,
+        };
+    }
+
     pub const Tag = enum(usize) {
         void,
         bool,
@@ -195,10 +300,20 @@ pub const Type = extern union {
         queue,
         pipe_storage,
         named_barrier,
+        u8,
+        u16,
+        u32,
+        u64,
+        i8,
+        i16,
+        i32,
+        i64,
+        f16,
+        f32,
+        f64,
 
         // After this, the tag requires a payload.
         int,
-        float,
         vector,
         matrix,
         image,
@@ -211,14 +326,33 @@ pub const Type = extern union {
         function,
         pipe,
 
-        pub const last_no_payload_tag = Tag.named_barrier;
+        pub const last_no_payload_tag = Tag.f64;
         pub const no_payload_count = @enumToInt(last_no_payload_tag) + 1;
 
         pub fn Type(comptime t: Tag) type {
             return switch (t) {
-                .void, .bool, .sampler, .event, .device_event, .reserve_id, .queue, .pipe_storage, .named_barrier => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"),
+                .void,
+                .bool,
+                .sampler,
+                .event,
+                .device_event,
+                .reserve_id,
+                .queue,
+                .pipe_storage,
+                .named_barrier,
+                .u8,
+                .u16,
+                .u32,
+                .u64,
+                .i8,
+                .i16,
+                .i32,
+                .i64,
+                .f16,
+                .f32,
+                .f64,
+                => @compileError("Type Tag " ++ @tagName(t) ++ " has no payload"),
                 .int => Payload.Int,
-                .float => Payload.Float,
                 .vector => Payload.Vector,
                 .matrix => Payload.Matrix,
                 .image => Payload.Image,
@@ -239,13 +373,8 @@ pub const Type = extern union {
 
         pub const Int = struct {
             base: Payload = .{ .tag = .int },
-            width: u32,
-            signedness: std.builtin.Signedness,
-        };
-
-        pub const Float = struct {
-            base: Payload = .{ .tag = .float },
-            width: u32,
+            width: u16,
+            signedness: Signedness,
         };
 
         pub const Vector = struct {
src/codegen/spirv.zig
@@ -451,12 +451,7 @@ pub const DeclGen = struct {
             return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits });
         };
 
-        const payload = try self.spv.arena.create(SpvType.Payload.Int);
-        payload.* = .{
-            .width = backing_bits,
-            .signedness = signedness,
-        };
-        return try self.spv.resolveType(SpvType.initPayload(&payload.base));
+        return try self.spv.resolveType(try SpvType.int(self.spv.arena, signedness, backing_bits));
     }
 
     /// Turn a Zig type into a SPIR-V Type, and return a reference to it.
@@ -495,11 +490,7 @@ pub const DeclGen = struct {
                     return self.fail("Floating point width of {} bits is not supported for the current SPIR-V feature set", .{bits});
                 }
 
-                const payload = try self.spv.arena.create(SpvType.Payload.Float);
-                payload.* = .{
-                    .width = bits,
-                };
-                return try self.spv.resolveType(SpvType.initPayload(&payload.base));
+                return try self.spv.resolveType(SpvType.float(bits));
             },
             .Fn => {
                 // TODO: Put this somewhere in Sema.zig