Commit 06d9e3b2eb

Robin Voetter <robin@voetter.nl>
2023-09-17 02:54:53
spirv: always emit unsigned integers
This is required for SPIR-V in Kernel mode. The Intel implementation just didn't care about this fact.
1 parent 18d0909
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -408,7 +408,7 @@ pub const DeclGen = struct {
         switch (repr) {
             .indirect => {
                 const int_ty_ref = try self.intType(.unsigned, 1);
-                return self.spv.constInt(int_ty_ref, @intFromBool(value));
+                return self.constInt(int_ty_ref, @intFromBool(value));
             },
             .direct => {
                 const bool_ty_ref = try self.resolveType(Type.bool, .direct);
@@ -417,6 +417,25 @@ pub const DeclGen = struct {
         }
     }
 
+    /// Emits an integer constant.
+    /// This function, unlike SpvModule.constInt, takes care to bitcast
+    /// the value to an unsigned int first for Kernels.
+    fn constInt(self: *DeclGen, ty_ref: CacheRef, value: anytype) !IdRef {
+        if (value < 0) {
+            const ty = self.spv.cache.lookup(ty_ref).int_type;
+            // Manually truncate the value so that the resulting value
+            // fits within the unsigned type.
+            const bits: u64 = @bitCast(@as(i64, @intCast(value)));
+            const truncated_bits = if (ty.bits == 64)
+                bits
+            else
+                bits & (@as(u64, 1) << @intCast(ty.bits)) - 1;
+            return try self.spv.constInt(ty_ref, truncated_bits);
+        } else {
+            return try self.spv.constInt(ty_ref, value);
+        }
+    }
+
     /// Construct a struct at runtime.
     /// result_ty_ref must be a struct type.
     fn constructStruct(self: *DeclGen, result_ty_ref: CacheRef, constituents: []const IdRef) !IdRef {
@@ -434,7 +453,7 @@ pub const DeclGen = struct {
         const member_types = spv_composite_ty.member_types;
 
         for (constituents, member_types, 0..) |constitent_id, member_ty_ref, index| {
-            const index_id = try self.spv.constInt(index_ty_ref, index);
+            const index_id = try self.constInt(index_ty_ref, index);
             const ptr_member_ty_ref = try self.spv.ptrType(member_ty_ref, .Generic);
             const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{index_id});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
@@ -469,7 +488,7 @@ pub const DeclGen = struct {
         const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Generic);
 
         for (constituents, 0..) |constitent_id, index| {
-            const index_id = try self.spv.constInt(index_ty_ref, index);
+            const index_id = try self.constInt(index_ty_ref, index);
             const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{index_id});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
                 .pointer = ptr_id,
@@ -580,17 +599,14 @@ pub const DeclGen = struct {
                 .generic_poison,
                 => unreachable, // non-runtime values
 
-                .false, .true => switch (repr) {
-                    .direct => return try self.spv.constBool(result_ty_ref, val.toBool()),
-                    .indirect => return try self.spv.constInt(result_ty_ref, @intFromBool(val.toBool())),
-                },
+                .false, .true => return try self.constBool(val.toBool(), repr),
             },
 
             .int => {
                 if (ty.isSignedInt(mod)) {
-                    return try self.spv.constInt(result_ty_ref, val.toSignedInt(mod));
+                    return try self.constInt(result_ty_ref, val.toSignedInt(mod));
                 } else {
-                    return try self.spv.constInt(result_ty_ref, val.toUnsignedInt(mod));
+                    return try self.constInt(result_ty_ref, val.toUnsignedInt(mod));
                 }
             },
             .float => return switch (ty.floatBits(target)) {
@@ -602,7 +618,7 @@ pub const DeclGen = struct {
             },
             .err => |err| {
                 const value = try mod.getErrorValue(err.name);
-                return try self.spv.constInt(result_ty_ref, value);
+                return try self.constInt(result_ty_ref, value);
             },
             .error_union => |error_union| {
                 // TODO: Error unions may be constructed with constant instructions if the payload type
@@ -716,7 +732,7 @@ pub const DeclGen = struct {
                             // TODO: This is really space inefficient, perhaps there is a better
                             // way to do it?
                             for (bytes, 0..) |byte, i| {
-                                constituents[i] = try self.spv.constInt(elem_ty_ref, byte);
+                                constituents[i] = try self.constInt(elem_ty_ref, byte);
                             }
                         },
                         .elems => |elems| {
@@ -794,7 +810,7 @@ pub const DeclGen = struct {
                 const index_ty_ref = try self.intType(.unsigned, 32);
 
                 if (layout.tag_size != 0) {
-                    const index_id = try self.spv.constInt(index_ty_ref, @as(u32, @intCast(layout.tag_index)));
+                    const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.tag_index)));
                     const tag_ty = ty.unionTagTypeSafety(mod).?;
                     const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
                     const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
@@ -807,7 +823,7 @@ pub const DeclGen = struct {
                 }
 
                 if (layout.active_field_size != 0) {
-                    const index_id = try self.spv.constInt(index_ty_ref, @as(u32, @intCast(layout.active_field_index)));
+                    const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.active_field_index)));
                     const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
                     const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
                     const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{index_id});
@@ -870,7 +886,9 @@ pub const DeclGen = struct {
             // An array of largestSupportedIntBits.
             return self.todo("Implement {s} composite int type of {} bits", .{ @tagName(signedness), bits });
         };
-        return self.spv.intType(signedness, backing_bits);
+        // Kernel only supports unsigned ints.
+        // TODO: Only do this with Kernels
+        return self.spv.intType(.unsigned, backing_bits);
     }
 
     /// Create an integer type that represents 'usize'.
@@ -1568,8 +1586,8 @@ pub const DeclGen = struct {
     }
 
     fn intFromBool(self: *DeclGen, result_ty_ref: CacheRef, condition_id: IdRef) !IdRef {
-        const zero_id = try self.spv.constInt(result_ty_ref, 0);
-        const one_id = try self.spv.constInt(result_ty_ref, 1);
+        const zero_id = try self.constInt(result_ty_ref, 0);
+        const one_id = try self.constInt(result_ty_ref, 1);
         const result_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpSelect, .{
             .id_result_type = self.typeId(result_ty_ref),
@@ -1589,7 +1607,7 @@ pub const DeclGen = struct {
             .Bool => blk: {
                 const direct_bool_ty_ref = try self.resolveType(ty, .direct);
                 const indirect_bool_ty_ref = try self.resolveType(ty, .indirect);
-                const zero_id = try self.spv.constInt(indirect_bool_ty_ref, 0);
+                const zero_id = try self.constInt(indirect_bool_ty_ref, 0);
                 const result_id = self.spv.allocId();
                 try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
                     .id_result_type = self.typeId(direct_bool_ty_ref),
@@ -1832,7 +1850,7 @@ pub const DeclGen = struct {
     fn maskStrangeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, bits: u16) !IdRef {
         const mask_value = if (bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(bits))) - 1;
         const result_id = self.spv.allocId();
-        const mask_id = try self.spv.constInt(ty_ref, mask_value);
+        const mask_id = try self.constInt(ty_ref, mask_value);
         try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
             .id_result_type = self.typeId(ty_ref),
             .id_result = result_id,
@@ -1971,7 +1989,7 @@ pub const DeclGen = struct {
                 // Note that signed overflow is also wrapping in spir-v.
 
                 const rhs_lt_zero_id = self.spv.allocId();
-                const zero_id = try self.spv.constInt(operand_ty_ref, 0);
+                const zero_id = try self.constInt(operand_ty_ref, 0);
                 try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
                     .id_result_type = self.typeId(bool_ty_ref),
                     .id_result = rhs_lt_zero_id,
@@ -2540,7 +2558,7 @@ pub const DeclGen = struct {
                 .Packed => unreachable, // TODO
                 else => {
                     const field_index_ty_ref = try self.intType(.unsigned, 32);
-                    const field_index_id = try self.spv.constInt(field_index_ty_ref, field_index);
+                    const field_index_id = try self.constInt(field_index_ty_ref, field_index);
                     const result_ty_ref = try self.resolveType(result_ptr_ty, .direct);
                     return try self.accessChain(result_ty_ref, object_ptr, &.{field_index_id});
                 },
@@ -2822,7 +2840,7 @@ pub const DeclGen = struct {
             else
                 err_union_id;
 
-            const zero_id = try self.spv.constInt(err_ty_ref, 0);
+            const zero_id = try self.constInt(err_ty_ref, 0);
             const is_err_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
                 .id_result_type = self.typeId(bool_ty_ref),