Commit f858bf1616

Robin Voetter <robin@voetter.nl>
2023-10-08 12:09:25
spirv: air bitcast for non-numeric non-pointer types
1 parent 0af16a5
Changed files (3)
src
codegen
test
behavior
src/codegen/spirv/Cache.zig
@@ -435,6 +435,13 @@ pub const Key = union(enum) {
             else => unreachable,
         };
     }
+
+    pub fn isNumericalType(self: Key) bool {
+        return switch (self) {
+            .int_type, .float_type => true,
+            else => false,
+        };
+    }
 };
 
 pub fn deinit(self: *Self, spv: *const Module) void {
src/codegen/spirv.zig
@@ -2578,25 +2578,52 @@ const DeclGen = struct {
             return src_id;
         }
 
-        const result_id = self.spv.allocId();
-
         // TODO: Some more cases are missing here
         //   See fn bitCast in llvm.zig
 
         if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
+            const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
                 .id_result_type = self.typeId(dst_ty_ref),
                 .id_result = result_id,
                 .integer_value = src_id,
             });
-        } else {
+            return result_id;
+        }
+
+        // We can only use OpBitcast for specific conversions: between numerical types, and
+        // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
+        // otherwise use a temporary and perform a pointer cast.
+        const src_key = self.spv.cache.lookup(src_ty_ref);
+        const dst_key = self.spv.cache.lookup(dst_ty_ref);
+
+        if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
+            const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
                 .id_result_type = self.typeId(dst_ty_ref),
                 .id_result = result_id,
                 .operand = src_id,
             });
+            return result_id;
         }
-        return result_id;
+
+        const src_ptr_ty_ref = try self.spv.ptrType(src_ty_ref, .Function);
+        const dst_ptr_ty_ref = try self.spv.ptrType(dst_ty_ref, .Function);
+
+        const tmp_id = self.spv.allocId();
+        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+            .id_result_type = self.typeId(src_ptr_ty_ref),
+            .id_result = tmp_id,
+            .storage_class = .Function,
+        });
+        try self.store(src_ty, tmp_id, src_id, false);
+        const casted_ptr_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+            .id_result_type = self.typeId(dst_ptr_ty_ref),
+            .id_result = casted_ptr_id,
+            .operand = tmp_id,
+        });
+        return try self.load(dst_ty, casted_ptr_id, false);
     }
 
     fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
test/behavior/cast.zig
@@ -899,7 +899,6 @@ test "peer cast [:x]T to []T" {
 test "peer cast [N:x]T to [N]T" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -1728,7 +1727,6 @@ test "peer type resolution: error union and optional of same type" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
 
     const E = error{Foo};
     var a: E!*u8 = error.Foo;