Commit 408c117246

Robin Voetter <robin@voetter.nl>
2024-01-21 15:54:27
spirv: air is_(non_)null_ptr, optional_payload_ptr
1 parent 7dfd403
Changed files (4)
src
codegen
test
src/codegen/spirv.zig
@@ -2273,13 +2273,16 @@ const DeclGen = struct {
             .wrap_errunion_err => try self.airWrapErrUnionErr(inst),
             .wrap_errunion_payload => try self.airWrapErrUnionPayload(inst),
 
-            .is_null     => try self.airIsNull(inst, .is_null),
-            .is_non_null => try self.airIsNull(inst, .is_non_null),
-            .is_err      => try self.airIsErr(inst, .is_err),
-            .is_non_err  => try self.airIsErr(inst, .is_non_err),
+            .is_null         => try self.airIsNull(inst, false, .is_null),
+            .is_non_null     => try self.airIsNull(inst, false, .is_non_null),
+            .is_null_ptr     => try self.airIsNull(inst, true, .is_null),
+            .is_non_null_ptr => try self.airIsNull(inst, true, .is_non_null),
+            .is_err          => try self.airIsErr(inst, .is_err),
+            .is_non_err      => try self.airIsErr(inst, .is_non_err),
 
-            .optional_payload => try self.airUnwrapOptional(inst),
-            .wrap_optional    => try self.airWrapOptional(inst),
+            .optional_payload     => try self.airUnwrapOptional(inst),
+            .optional_payload_ptr => try self.airUnwrapOptionalPtr(inst),
+            .wrap_optional        => try self.airWrapOptional(inst),
 
             .assembly => try self.airAssembly(inst),
 
@@ -4726,20 +4729,24 @@ const DeclGen = struct {
         return try self.constructStruct(err_union_ty, &types, &members);
     }
 
-    fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
+    fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, is_pointer: bool, pred: enum { is_null, is_non_null }) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
         const mod = self.module;
         const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op;
         const operand_id = try self.resolve(un_op);
-        const optional_ty = self.typeOf(un_op);
-
+        const operand_ty = self.typeOf(un_op);
+        const optional_ty = if (is_pointer) operand_ty.childType(mod) else operand_ty;
         const payload_ty = optional_ty.optionalChild(mod);
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
         if (optional_ty.optionalReprIsPayload(mod)) {
             // Pointer payload represents nullability: pointer or slice.
+            const loaded_id = if (is_pointer)
+                try self.load(optional_ty, operand_id, .{})
+            else
+                operand_id;
 
             const ptr_ty = if (payload_ty.isSlice(mod))
                 payload_ty.slicePtrFieldType(mod)
@@ -4747,9 +4754,9 @@ const DeclGen = struct {
                 payload_ty;
 
             const ptr_id = if (payload_ty.isSlice(mod))
-                try self.extractField(ptr_ty, operand_id, 0)
+                try self.extractField(ptr_ty, loaded_id, 0)
             else
-                operand_id;
+                loaded_id;
 
             const payload_ty_ref = try self.resolveType(ptr_ty, .direct);
             const null_id = try self.spv.constNull(payload_ty_ref);
@@ -4760,13 +4767,26 @@ const DeclGen = struct {
             return try self.cmp(op, Type.bool, ptr_ty, ptr_id, null_id);
         }
 
-        const is_non_null_id = if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
-            try self.extractField(Type.bool, operand_id, 1)
-        else
-            // Optional representation is bool indicating whether the optional is set
-            // Optionals with no payload are represented as an (indirect) bool, so convert
-            // it back to the direct bool here.
-            try self.convertToDirect(Type.bool, operand_id);
+        const is_non_null_id = blk: {
+            if (is_pointer) {
+                if (payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+                    const storage_class = spvStorageClass(operand_ty.ptrAddressSpace(mod));
+                    const bool_ptr_ty = try self.ptrType(Type.bool, storage_class);
+                    const tag_ptr_id = try self.accessChain(bool_ptr_ty, operand_id, &.{1});
+                    break :blk try self.load(Type.bool, tag_ptr_id, .{});
+                }
+
+                break :blk try self.load(Type.bool, operand_id, .{});
+            }
+
+            break :blk if (payload_ty.hasRuntimeBitsIgnoreComptime(mod))
+                try self.extractField(Type.bool, operand_id, 1)
+            else
+                // Optional representation is bool indicating whether the optional is set
+                // Optionals with no payload are represented as an (indirect) bool, so convert
+                // it back to the direct bool here.
+                try self.convertToDirect(Type.bool, operand_id);
+        };
 
         return switch (pred) {
             .is_null => blk: {
@@ -4837,6 +4857,32 @@ const DeclGen = struct {
         return try self.extractField(payload_ty, operand_id, 0);
     }
 
+    fn airUnwrapOptionalPtr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        const operand_ty = self.typeOf(ty_op.operand);
+        const optional_ty = operand_ty.childType(mod);
+        const payload_ty = optional_ty.optionalChild(mod);
+        const result_ty = self.typeOfIndex(inst);
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
+
+        if (!payload_ty.hasRuntimeBitsIgnoreComptime(mod)) {
+            // There is no payload, but we still need to return a valid pointer.
+            // We can just return anything here, so just return a pointer to the operand.
+            return try self.bitCast(result_ty, operand_ty, operand_id);
+        }
+
+        if (optional_ty.optionalReprIsPayload(mod)) {
+            // They are the same value.
+            return try self.bitCast(result_ty, operand_ty, operand_id);
+        }
+
+        return try self.accessChain(result_ty_ref, operand_id, &.{0});
+    }
+
     fn airWrapOptional(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
test/behavior/cast.zig
@@ -1247,7 +1247,6 @@ test "implicit cast from *[N]T to ?[*]T" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var x: ?[*]u16 = null;
     var y: [4]u16 = [4]u16{ 0, 1, 2, 3 };
test/behavior/null.zig
@@ -32,7 +32,6 @@ test "test maybe object and get a pointer to the inner value" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var maybe_bool: ?bool = true;
 
@@ -142,7 +141,6 @@ test "if var maybe pointer" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try expect(shouldBeAPlus1(Particle{
         .a = 14,
test/behavior/optional.zig
@@ -72,7 +72,6 @@ test "address of unwrap optional" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const Foo = struct {
@@ -341,7 +340,6 @@ test "optional pointer to zero bit optional payload" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const B = struct {
         fn foo(_: *@This()) void {}
@@ -518,7 +516,6 @@ test "copied optional doesn't alias source" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var opt_x: ?[3]f32 = [_]f32{0.0} ** 3;