Commit 345d6e280d

Robin Voetter <robin@voetter.nl>
2024-01-21 01:41:41
spirv: air int_from_bool
1 parent 77ef78a
Changed files (3)
src
codegen
test
src/codegen/spirv.zig
@@ -2202,6 +2202,7 @@ const DeclGen = struct {
             .int_from_ptr    => try self.airIntFromPtr(inst),
             .float_from_int  => try self.airFloatFromInt(inst),
             .int_from_float  => try self.airIntFromFloat(inst),
+            .int_from_bool   => try self.airIntFromBool(inst),
             .fpext, .fptrunc => try self.airFloatCast(inst),
             .not             => try self.airNot(inst),
 
@@ -3174,50 +3175,64 @@ const DeclGen = struct {
         const mod = self.module;
         const src_ty_ref = try self.resolveType(src_ty, .direct);
         const dst_ty_ref = try self.resolveType(dst_ty, .direct);
-        if (src_ty_ref == dst_ty_ref) {
-            return src_id;
-        }
+        const src_key = self.spv.cache.lookup(src_ty_ref);
+        const dst_key = self.spv.cache.lookup(dst_ty_ref);
 
-        // TODO: Some more cases are missing here
-        //   See fn bitCast in llvm.zig
+        const result_id = blk: {
+            if (src_ty_ref == dst_ty_ref) {
+                break :blk src_id;
+            }
 
-        if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
-            const result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .integer_value = src_id,
-            });
-            return result_id;
-        }
+            // TODO: Some more cases are missing here
+            //   See fn bitCast in llvm.zig
 
-        // We can only use OpBitcast for specific conversions: between numerical types, and
-        // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
-        // otherwise use a temporary and perform a pointer cast.
-        const src_key = self.spv.cache.lookup(src_ty_ref);
-        const dst_key = self.spv.cache.lookup(dst_ty_ref);
+            if (src_ty.zigTypeTag(mod) == .Int and dst_ty.isPtrAtRuntime(mod)) {
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpConvertUToPtr, .{
+                    .id_result_type = self.typeId(dst_ty_ref),
+                    .id_result = result_id,
+                    .integer_value = src_id,
+                });
+                break :blk result_id;
+            }
+
+            // We can only use OpBitcast for specific conversions: between numerical types, and
+            // between pointers. If the resolved spir-v types fall into this category then emit OpBitcast,
+            // otherwise use a temporary and perform a pointer cast.
+            if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+                    .id_result_type = self.typeId(dst_ty_ref),
+                    .id_result = result_id,
+                    .operand = src_id,
+                });
+
+                break :blk result_id;
+            }
 
-        if ((src_key.isNumericalType() and dst_key.isNumericalType()) or (src_key == .ptr_type and dst_key == .ptr_type)) {
-            const result_id = self.spv.allocId();
+            const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
+
+            const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
+            try self.store(src_ty, tmp_id, src_id, .{});
+            const casted_ptr_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
-                .id_result_type = self.typeId(dst_ty_ref),
-                .id_result = result_id,
-                .operand = src_id,
+                .id_result_type = self.typeId(dst_ptr_ty_ref),
+                .id_result = casted_ptr_id,
+                .operand = tmp_id,
             });
-            return result_id;
-        }
+            break :blk try self.load(dst_ty, casted_ptr_id, .{});
+        };
 
-        const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
+        // Because strange integers use sign-extended representation, we may need to normalize
+        // the result here.
+        // TODO: This detail could cause stuff like @as(*const i1, @ptrCast(&@as(u1, 1))) to break
+        // should we change the representation of strange integers?
+        if (dst_ty.zigTypeTag(mod) == .Int) {
+            const info = self.arithmeticTypeInfo(dst_ty);
+            return try self.normalize(dst_ty_ref, result_id, info);
+        }
 
-        const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
-        try self.store(src_ty, tmp_id, src_id, .{});
-        const casted_ptr_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
-            .id_result_type = self.typeId(dst_ptr_ty_ref),
-            .id_result = casted_ptr_id,
-            .operand = tmp_id,
-        });
-        return try self.load(dst_ty, casted_ptr_id, .{});
+        return result_id;
     }
 
     fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3340,6 +3355,22 @@ const DeclGen = struct {
         return result_id;
     }
 
+    fn airIntFromBool(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
+        const operand_id = try self.resolve(un_op);
+        const result_ty = self.typeOfIndex(inst);
+
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const elem_id = try wip.elementAt(Type.bool, operand_id, i);
+            result_id.* = try self.intFromBool(wip.scalar_ty_ref, elem_id);
+        }
+        return try wip.finalize();
+    }
+
     fn airFloatCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
test/behavior/bool.zig
@@ -9,8 +9,6 @@ test "bool literals" {
 }
 
 test "cast bool to int" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const t = true;
     const f = false;
     try expectEqual(@as(u32, 1), @intFromBool(t));
test/behavior/cast.zig
@@ -2430,7 +2430,6 @@ test "@intFromBool on vector" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {