Commit 83ab1ba8fd

Robin Voetter <robin@voetter.nl>
2023-04-11 22:13:54
spirv: lower air is_null, is_non_null
Implements AIR lowering for is_null and is_non_null tags. Additionally this cleans up and centralizes the logic to convert from 'direct' representation to 'indirect' representation and vice-versa. The related functions, as well as the functions that use it, are all moved near eachother so that the conversion logic remains in a central place. Extracting/inserting fields and loading/storing pointers should go through these functions.
1 parent 435a566
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -402,9 +402,21 @@ pub const DeclGen = struct {
         return result_id;
     }
 
-    fn constUndef(self: *DeclGen, ty_ref: SpvType.Ref) Error!IdRef {
+    fn constUndef(self: *DeclGen, ty_ref: SpvType.Ref) !IdRef {
         const result_id = self.spv.allocId();
-        try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id });
+        try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{
+            .id_result_type = self.typeId(ty_ref),
+            .id_result = result_id,
+        });
+        return result_id;
+    }
+
+    fn constNull(self: *DeclGen, ty_ref: SpvType.Ref) !IdRef {
+        const result_id = self.spv.allocId();
+        try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpConstantNull, .{
+            .id_result_type = self.typeId(ty_ref),
+            .id_result = result_id,
+        });
         return result_id;
     }
 
@@ -674,7 +686,7 @@ pub const DeclGen = struct {
                         try self.addConstBool(has_payload);
                         return;
                     } else if (ty.optionalReprIsPayload()) {
-                        // Optional representation is a nullable pointer.
+                        // Optional representation is a nullable pointer or slice.
                         if (val.castTag(.opt_payload)) |payload| {
                             try self.lower(payload_ty, payload.data);
                         } else if (has_payload) {
@@ -1257,7 +1269,7 @@ pub const DeclGen = struct {
 
                 const payload_ty_ref = try self.resolveType(payload_ty, .indirect);
                 if (ty.optionalReprIsPayload()) {
-                    // Optional is actually a pointer.
+                    // Optional is actually a pointer or a slice.
                     return payload_ty_ref;
                 }
 
@@ -1523,6 +1535,93 @@ pub const DeclGen = struct {
         }
     }
 
+    /// Convert representation from indirect (in memory) to direct (in 'register')
+    /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct).
+    fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
+        // const direct_ty_ref = try self.resolveType(ty, .direct);
+        return switch (ty.zigTypeTag()) {
+            .Bool => blk: {
+                const direct_bool_ty_ref = try self.resolveType(ty, .direct);
+                const indirect_bool_ty_ref = try self.resolveType(ty, .indirect);
+                const zero_id = try self.constInt(indirect_bool_ty_ref, 0);
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                    .id_result_type = self.typeId(direct_bool_ty_ref),
+                    .id_result = result_id,
+                    .operand_1 = operand_id,
+                    .operand_2 = zero_id,
+                });
+                break :blk result_id;
+            },
+            else => operand_id,
+        };
+    }
+
+    /// Convert representation from direct (in 'register) to direct (in memory)
+    /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect).
+    fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
+        return switch (ty.zigTypeTag()) {
+            .Bool => blk: {
+                const indirect_bool_ty_ref = try self.resolveType(ty, .indirect);
+                const zero_id = try self.constInt(indirect_bool_ty_ref, 0);
+                const one_id = try self.constInt(indirect_bool_ty_ref, 1);
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                    .id_result_type = self.typeId(indirect_bool_ty_ref),
+                    .id_result = result_id,
+                    .condition = operand_id,
+                    .object_1 = one_id,
+                    .object_2 = zero_id,
+                });
+                break :blk result_id;
+            },
+            else => operand_id,
+        };
+    }
+
+    fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef {
+        const result_ty_ref = try self.resolveType(result_ty, .indirect);
+        const result_id = self.spv.allocId();
+        const indexes = [_]u32{field};
+        try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
+            .id_result_type = self.typeId(result_ty_ref),
+            .id_result = result_id,
+            .composite = object,
+            .indexes = &indexes,
+        });
+        // Convert bools; direct structs have their field types as indirect values.
+        return try self.convertToDirect(result_ty, result_id);
+    }
+
+    fn load(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef) !IdRef {
+        const value_ty = ptr_ty.childType();
+        const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect);
+        const result_id = self.spv.allocId();
+        const access = spec.MemoryAccess.Extended{
+            .Volatile = ptr_ty.isVolatilePtr(),
+        };
+        try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+            .id_result_type = self.typeId(indirect_value_ty_ref),
+            .id_result = result_id,
+            .pointer = ptr_id,
+            .memory_access = access,
+        });
+        return try self.convertToDirect(value_ty, result_id);
+    }
+
+    fn store(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, value_id: IdRef) !void {
+        const value_ty = ptr_ty.childType();
+        const indirect_value_id = try self.convertToIndirect(value_ty, value_id);
+        const access = spec.MemoryAccess.Extended{
+            .Volatile = ptr_ty.isVolatilePtr(),
+        };
+        try self.func.body.emit(self.spv.gpa, .OpStore, .{
+            .pointer = ptr_id,
+            .object = indirect_value_id,
+            .memory_access = access,
+        });
+    }
+
     fn genBody(self: *DeclGen, body: []const Air.Inst.Index) Error!void {
         for (body) |inst| {
             try self.genInst(inst);
@@ -1615,6 +1714,9 @@ pub const DeclGen = struct {
             .unwrap_errunion_err => try self.airErrUnionErr(inst),
             .wrap_errunion_err => try self.airWrapErrUnionErr(inst),
 
+            .is_null     => try self.airIsNull(inst, .is_null),
+            .is_non_null => try self.airIsNull(inst, .is_non_null),
+
             .assembly => try self.airAssembly(inst),
 
             .call              => try self.airCall(inst, .auto),
@@ -1776,18 +1878,17 @@ pub const DeclGen = struct {
             .float, .bool => unreachable,
         }
 
-        const operand_ty_id = try self.resolveTypeId(operand_ty);
-        const result_type_id = try self.resolveTypeId(result_ty);
-
-        const overflow_member_ty_ref = try self.intType(.unsigned, info.bits);
+        // The operand type must be the same as the result type in SPIR-V.
+        const operand_ty_ref = try self.resolveType(operand_ty, .direct);
+        const operand_ty_id = self.typeId(operand_ty_ref);
 
         const op_result_id = blk: {
             // Construct the SPIR-V result type.
             // It is almost the same as the zig one, except that the fields must be the same type
             // and they must be unsigned.
             const overflow_result_ty_ref = try self.spv.simpleStructType(&.{
-                .{ .ty = overflow_member_ty_ref, .name = "res" },
-                .{ .ty = overflow_member_ty_ref, .name = "ov" },
+                .{ .ty = operand_ty_ref, .name = "res" },
+                .{ .ty = operand_ty_ref, .name = "ov" },
             });
             const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpIAddCarry, .{
@@ -1801,11 +1902,13 @@ pub const DeclGen = struct {
 
         // Now convert the SPIR-V flavor result into a Zig-flavor result.
         // First, extract the two fields.
-        const unsigned_result = try self.extractField(overflow_member_ty_ref, op_result_id, 0);
-        const overflow = try self.extractField(overflow_member_ty_ref, op_result_id, 1);
+        const unsigned_result = try self.extractField(operand_ty, op_result_id, 0);
+        const overflow = try self.extractField(operand_ty, op_result_id, 1);
 
         // We need to convert the results to the types that Zig expects here.
         // The `result` is the same type except unsigned, so we can just bitcast that.
+        // TODO: This can be removed in Kernels as there are only unsigned ints. Maybe for
+        // shaders as well?
         const result = try self.bitcast(operand_ty_id, unsigned_result);
 
         // The overflow needs to be converted into whatever is used to represent it in Zig.
@@ -1828,7 +1931,7 @@ pub const DeclGen = struct {
         const result_id = self.spv.allocId();
         const constituents = [_]IdRef{ result, casted_overflow };
         try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
-            .id_result_type = result_type_id,
+            .id_result_type = operand_ty_id,
             .id_result = result_id,
             .constituents = &constituents,
         });
@@ -1980,25 +2083,14 @@ pub const DeclGen = struct {
         return result_id;
     }
 
-    fn extractField(self: *DeclGen, result_ty_ref: SpvType.Ref, object: IdRef, field: u32) !IdRef {
-        const result_id = self.spv.allocId();
-        const indexes = [_]u32{field};
-        try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .composite = object,
-            .indexes = &indexes,
-        });
-        // TODO: Convert bools, direct structs should have their field types as indirect values.
-        return result_id;
-    }
-
     fn airSliceField(self: *DeclGen, inst: Air.Inst.Index, field: u32) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const field_ty = self.air.typeOfIndex(inst);
+        const operand_id = try self.resolve(ty_op.operand);
         return try self.extractField(
-            try self.resolveType(self.air.typeOfIndex(inst), .direct),
-            try self.resolve(ty_op.operand),
+            field_ty,
+            operand_id,
             field,
         );
     }
@@ -2367,35 +2459,6 @@ pub const DeclGen = struct {
         return try self.load(ptr_ty, operand);
     }
 
-    fn load(self: *DeclGen, ptr_ty: Type, ptr: IdRef) !IdRef {
-        const value_ty = ptr_ty.childType();
-        const direct_result_ty_ref = try self.resolveType(value_ty, .direct);
-        const indirect_result_ty_ref = try self.resolveType(value_ty, .indirect);
-        const result_id = self.spv.allocId();
-        const access = spec.MemoryAccess.Extended{
-            .Volatile = ptr_ty.isVolatilePtr(),
-        };
-        try self.func.body.emit(self.spv.gpa, .OpLoad, .{
-            .id_result_type = self.typeId(indirect_result_ty_ref),
-            .id_result = result_id,
-            .pointer = ptr,
-            .memory_access = access,
-        });
-        if (value_ty.zigTypeTag() == .Bool) {
-            // Convert indirect bool to direct bool
-            const zero_id = try self.constInt(indirect_result_ty_ref, 0);
-            const casted_result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
-                .id_result_type = self.typeId(direct_result_ty_ref),
-                .id_result = casted_result_id,
-                .operand_1 = result_id,
-                .operand_2 = zero_id,
-            });
-            return casted_result_id;
-        }
-        return result_id;
-    }
-
     fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void {
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
         const ptr_ty = self.air.typeOf(bin_op.lhs);
@@ -2405,35 +2468,6 @@ pub const DeclGen = struct {
         try self.store(ptr_ty, ptr, value);
     }
 
-    fn store(self: *DeclGen, ptr_ty: Type, ptr: IdRef, value: IdRef) !void {
-        const value_ty = ptr_ty.childType();
-        const converted_value = switch (value_ty.zigTypeTag()) {
-            .Bool => blk: {
-                const indirect_bool_ty_ref = try self.resolveType(value_ty, .indirect);
-                const result_id = self.spv.allocId();
-                const zero = try self.constInt(indirect_bool_ty_ref, 0);
-                const one = try self.constInt(indirect_bool_ty_ref, 1);
-                try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                    .id_result_type = self.typeId(indirect_bool_ty_ref),
-                    .id_result = result_id,
-                    .condition = value,
-                    .object_1 = one,
-                    .object_2 = zero,
-                });
-                break :blk result_id;
-            },
-            else => value,
-        };
-        const access = spec.MemoryAccess.Extended{
-            .Volatile = ptr_ty.isVolatilePtr(),
-        };
-        try self.func.body.emit(self.spv.gpa, .OpStore, .{
-            .pointer = ptr,
-            .object = converted_value,
-            .memory_access = access,
-        });
-    }
-
     fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void {
         const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
         const loop = self.air.extraData(Air.Block, ty_pl.payload);
@@ -2488,14 +2522,13 @@ pub const DeclGen = struct {
         const payload_ty = self.air.typeOfIndex(inst);
 
         const err_ty_ref = try self.resolveType(Type.anyerror, .direct);
-        const payload_ty_ref = try self.resolveType(payload_ty, .direct);
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
         const eu_layout = self.errorUnionLayout(payload_ty);
 
         if (!err_union_ty.errorUnionSet().errorSetIsEmpty()) {
             const err_id = if (eu_layout.payload_has_bits)
-                try self.extractField(err_ty_ref, err_union_id, eu_layout.errorFieldIndex())
+                try self.extractField(Type.anyerror, err_union_id, eu_layout.errorFieldIndex())
             else
                 err_union_id;
 
@@ -2535,7 +2568,7 @@ pub const DeclGen = struct {
             return null;
         }
 
-        return try self.extractField(payload_ty_ref, err_union_id, eu_layout.payloadFieldIndex());
+        return try self.extractField(payload_ty, err_union_id, eu_layout.payloadFieldIndex());
     }
 
     fn airErrUnionErr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2559,7 +2592,7 @@ pub const DeclGen = struct {
             return operand_id;
         }
 
-        return try self.extractField(err_ty_ref, operand_id, eu_layout.errorFieldIndex());
+        return try self.extractField(Type.anyerror, operand_id, eu_layout.errorFieldIndex());
     }
 
     fn airWrapErrUnionErr(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2598,6 +2631,69 @@ pub const DeclGen = struct {
         return result_id;
     }
 
+    fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, pred: enum { is_null, is_non_null }) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const un_op = self.air.instructions.items(.data)[inst].un_op;
+        const operand_id = try self.resolve(un_op);
+        const optional_ty = self.air.typeOf(un_op);
+
+        var buf: Type.Payload.ElemType = undefined;
+        const payload_ty = optional_ty.optionalChild(&buf);
+
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+
+        if (optional_ty.optionalReprIsPayload()) {
+            // Pointer payload represents nullability: pointer or slice.
+
+            var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined;
+            const ptr_ty = if (payload_ty.isSlice())
+                payload_ty.slicePtrFieldType(&ptr_buf)
+            else
+                payload_ty;
+
+            const ptr_id = if (payload_ty.isSlice())
+                try self.extractField(Type.bool, operand_id, 0)
+            else
+                operand_id;
+
+            const payload_ty_ref = try self.resolveType(ptr_ty, .direct);
+            const null_id = try self.constNull(payload_ty_ref);
+            const result_id = self.spv.allocId();
+            const operands = .{
+                .id_result_type = self.typeId(bool_ty_ref),
+                .id_result = result_id,
+                .operand_1 = ptr_id,
+                .operand_2 = null_id,
+            };
+            switch (pred) {
+                .is_null => try self.func.body.emit(self.spv.gpa, .OpPtrEqual, operands),
+                .is_non_null => try self.func.body.emit(self.spv.gpa, .OpPtrNotEqual, operands),
+            }
+            return result_id;
+        }
+
+        const is_non_null_id = if (optional_ty.hasRuntimeBitsIgnoreComptime())
+            try self.extractField(Type.bool, operand_id, 1)
+        else
+            // Optional representation is bool indicating whether the optional is set
+            operand_id;
+
+        return switch (pred) {
+            .is_null => blk: {
+                // Invert condition
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpLogicalNot, .{
+                    .id_result_type = self.typeId(bool_ty_ref),
+                    .id_result = result_id,
+                    .operand = is_non_null_id,
+                });
+                break :blk result_id;
+            },
+            .is_non_null => is_non_null_id,
+        };
+    }
+
     fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
         const target = self.getTarget();
         const pl_op = self.air.instructions.items(.data)[inst].pl_op;