Commit 77ef78a0ef

Robin Voetter <robin@voetter.nl>
2024-01-21 01:39:20
spirv: clean up arithmeticTypeInfo a bit
- No longer returns an error - Returns more useful vector info
1 parent 54ec936
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -373,8 +373,9 @@ const DeclGen = struct {
         /// For `composite_integer` this is 0 (TODO)
         backing_bits: u16,
 
-        /// Whether the type is a vector.
-        is_vector: bool,
+        /// Null if this type is a scalar, or the length
+        /// of the vector otherwise.
+        vector_len: ?u32,
 
         /// Whether the inner type is signed. Only relevant for integers.
         signedness: std.builtin.Signedness,
@@ -597,32 +598,37 @@ const DeclGen = struct {
         return self.backingIntBits(ty) == null;
     }
 
-    fn arithmeticTypeInfo(self: *DeclGen, ty: Type) !ArithmeticTypeInfo {
+    fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
         const mod = self.module;
         const target = self.getTarget();
-        return switch (ty.zigTypeTag(mod)) {
+        var scalar_ty = ty.scalarType(mod);
+        if (scalar_ty.zigTypeTag(mod) == .Enum) {
+            scalar_ty = scalar_ty.intTagType(mod);
+        }
+        const vector_len = if (ty.isVector(mod)) ty.vectorLen(mod) else null;
+        return switch (scalar_ty.zigTypeTag(mod)) {
             .Bool => ArithmeticTypeInfo{
                 .bits = 1, // Doesn't matter for this class.
                 .backing_bits = self.backingIntBits(1).?,
-                .is_vector = false,
+                .vector_len = vector_len,
                 .signedness = .unsigned, // Technically, but doesn't matter for this class.
                 .class = .bool,
             },
             .Float => ArithmeticTypeInfo{
-                .bits = ty.floatBits(target),
-                .backing_bits = ty.floatBits(target), // TODO: F80?
-                .is_vector = false,
+                .bits = scalar_ty.floatBits(target),
+                .backing_bits = scalar_ty.floatBits(target), // TODO: F80?
+                .vector_len = vector_len,
                 .signedness = .signed, // Technically, but doesn't matter for this class.
                 .class = .float,
             },
             .Int => blk: {
-                const int_info = ty.intInfo(mod);
+                const int_info = scalar_ty.intInfo(mod);
                 // TODO: Maybe it's useful to also return this value.
                 const maybe_backing_bits = self.backingIntBits(int_info.bits);
                 break :blk ArithmeticTypeInfo{
                     .bits = int_info.bits,
                     .backing_bits = maybe_backing_bits orelse 0,
-                    .is_vector = false,
+                    .vector_len = vector_len,
                     .signedness = int_info.signedness,
                     .class = if (maybe_backing_bits) |backing_bits|
                         if (backing_bits == int_info.bits)
@@ -633,22 +639,9 @@ const DeclGen = struct {
                         .composite_integer,
                 };
             },
-            .Enum => return self.arithmeticTypeInfo(ty.intTagType(mod)),
-            // As of yet, there is no vector support in the self-hosted compiler.
-            .Vector => blk: {
-                const child_type = ty.childType(mod);
-                const child_ty_info = try self.arithmeticTypeInfo(child_type);
-                break :blk ArithmeticTypeInfo{
-                    .bits = child_ty_info.bits,
-                    .backing_bits = child_ty_info.backing_bits,
-                    .is_vector = true,
-                    .signedness = child_ty_info.signedness,
-                    .class = child_ty_info.class,
-                };
-            },
-            // TODO: For which types is this the case?
-            // else => self.todo("implement arithmeticTypeInfo for {}", .{ty.fmt(self.module)}),
-            else => unreachable,
+            .Enum => unreachable,
+            .Vector => unreachable,
+            else => unreachable, // Unhandled arithmetic type
         };
     }
 
@@ -2336,7 +2329,7 @@ const DeclGen = struct {
         const shift_ty = self.typeOf(bin_op.rhs);
         const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
 
-        const info = try self.arithmeticTypeInfo(result_ty);
+        const info = self.arithmeticTypeInfo(result_ty);
         switch (info.class) {
             .composite_integer => return self.todo("shift ops for composite integers", .{}),
             .integer, .strange_integer => {},
@@ -2393,7 +2386,7 @@ const DeclGen = struct {
 
     fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
         const result_ty_ref = try self.resolveType(result_ty, .direct);
-        const info = try self.arithmeticTypeInfo(result_ty);
+        const info = self.arithmeticTypeInfo(result_ty);
 
         // TODO: Use fmin for OpenCL
         const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
@@ -2516,7 +2509,7 @@ const DeclGen = struct {
     ) !IdRef {
         // Binary operations are generally applicable to both scalar and vector operations
         // in SPIR-V, but int and float versions of operations require different opcodes.
-        const info = try self.arithmeticTypeInfo(ty);
+        const info = self.arithmeticTypeInfo(ty);
 
         const opcode_index: usize = switch (info.class) {
             .composite_integer => {
@@ -2579,7 +2572,7 @@ const DeclGen = struct {
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
-        const info = try self.arithmeticTypeInfo(operand_ty);
+        const info = self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
             .composite_integer => return self.todo("overflow ops for composite integers", .{}),
             .strange_integer, .integer => {},
@@ -2693,7 +2686,7 @@ const DeclGen = struct {
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
-        const info = try self.arithmeticTypeInfo(operand_ty);
+        const info = self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
             .composite_integer => return self.todo("overflow shift for composite integers", .{}),
             .integer, .strange_integer => {},
@@ -2777,7 +2770,7 @@ const DeclGen = struct {
         const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
         const scalar_ty_id = self.typeId(scalar_ty_ref);
 
-        const info = try self.arithmeticTypeInfo(operand_ty);
+        const info = self.arithmeticTypeInfo(operand_ty);
 
         var result_id = try self.extractField(scalar_ty, operand, 0);
         const len = operand_ty.vectorLen(mod);
@@ -3093,7 +3086,7 @@ const DeclGen = struct {
         };
 
         const opcode: Opcode = opcode: {
-            const info = try self.arithmeticTypeInfo(op_ty);
+            const info = self.arithmeticTypeInfo(op_ty);
             const signedness = switch (info.class) {
                 .composite_integer => {
                     return self.todo("binary operations for composite integers", .{});
@@ -3245,8 +3238,8 @@ const DeclGen = struct {
         const dst_ty = self.typeOfIndex(inst);
         const dst_ty_ref = try self.resolveType(dst_ty, .direct);
 
-        const src_info = try self.arithmeticTypeInfo(src_ty);
-        const dst_info = try self.arithmeticTypeInfo(dst_ty);
+        const src_info = self.arithmeticTypeInfo(src_ty);
+        const dst_info = self.arithmeticTypeInfo(dst_ty);
 
         if (src_info.backing_bits == dst_info.backing_bits) {
             return operand_id;
@@ -3302,7 +3295,7 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_ty = self.typeOf(ty_op.operand);
         const operand_id = try self.resolve(ty_op.operand);
-        const operand_info = try self.arithmeticTypeInfo(operand_ty);
+        const operand_info = self.arithmeticTypeInfo(operand_ty);
         const dest_ty = self.typeOfIndex(inst);
         const dest_ty_id = try self.resolveTypeId(dest_ty);
 
@@ -3328,7 +3321,7 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         const dest_ty = self.typeOfIndex(inst);
-        const dest_info = try self.arithmeticTypeInfo(dest_ty);
+        const dest_info = self.arithmeticTypeInfo(dest_ty);
         const dest_ty_id = try self.resolveTypeId(dest_ty);
 
         const result_id = self.spv.allocId();
@@ -3369,7 +3362,7 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         const result_ty = self.typeOfIndex(inst);
-        const info = try self.arithmeticTypeInfo(result_ty);
+        const info = self.arithmeticTypeInfo(result_ty);
 
         var wip = try self.elementWise(result_ty);
         defer wip.deinit();