Commit 0bae2caaf3

Robin Voetter <robin@voetter.nl>
2023-04-10 20:34:15
spirv: lower air try
Implements code generation for the try air tag. This commit also adds a utility `errorUnionLayout` function that helps keeping the layout of a spir-v error union consistent.
1 parent dfecf89
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -765,21 +765,18 @@ pub const DeclGen = struct {
                     const is_pl = val.errorUnionIsPayload();
                     const error_val = if (!is_pl) val else Value.initTag(.zero);
 
-                    if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
+                    const eu_layout = dg.errorUnionLayout(payload_ty);
+                    if (!eu_layout.payload_has_bits) {
                         return try self.lower(Type.anyerror, error_val);
                     }
 
-                    const payload_align = payload_ty.abiAlignment(target);
-                    const error_align = Type.anyerror.abiAlignment(target);
-
                     const payload_size = payload_ty.abiSize(target);
                     const error_size = Type.anyerror.abiAlignment(target);
                     const ty_size = ty.abiSize(target);
                     const padding = ty_size - payload_size - error_size;
-
                     const payload_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef);
 
-                    if (error_align > payload_align) {
+                    if (eu_layout.error_first) {
                         try self.lower(Type.anyerror, error_val);
                         try self.lower(payload_ty, payload_val);
                     } else {
@@ -1277,18 +1274,16 @@ pub const DeclGen = struct {
             .ErrorUnion => {
                 const payload_ty = ty.errorUnionPayload();
                 const error_ty_ref = try self.resolveType(Type.anyerror, .indirect);
-                if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
+
+                const eu_layout = self.errorUnionLayout(payload_ty);
+                if (!eu_layout.payload_has_bits) {
                     return error_ty_ref;
                 }
 
                 const payload_ty_ref = try self.resolveType(payload_ty, .indirect);
 
-                const payload_align = payload_ty.abiAlignment(target);
-                const error_align = Type.anyerror.abiAlignment(target);
-
                 var members = std.BoundedArray(SpvType.Payload.Struct.Member, 2){};
-                // Similar to unions, we're going to put the most aligned member first.
-                if (error_align > payload_align) {
+                if (eu_layout.error_first) {
                     // Put the error first
                     members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" });
                     members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" });
@@ -1336,6 +1331,34 @@ pub const DeclGen = struct {
         };
     }
 
+    const ErrorUnionLayout = struct {
+        payload_has_bits: bool,
+        error_first: bool,
+
+        fn errorFieldIndex(self: @This()) u32 {
+            assert(self.payload_has_bits);
+            return if (self.error_first) 0 else 1;
+        }
+
+        fn payloadFieldIndex(self: @This()) u32 {
+            assert(self.payload_has_bits);
+            return if (self.error_first) 1 else 0;
+        }
+    };
+
+    fn errorUnionLayout(self: *DeclGen, payload_ty: Type) ErrorUnionLayout {
+        const target = self.getTarget();
+
+        const error_align = Type.anyerror.abiAlignment(target);
+        const payload_align = payload_ty.abiAlignment(target);
+
+        const error_first = error_align > payload_align;
+        return .{
+            .payload_has_bits = payload_ty.hasRuntimeBitsIgnoreComptime(),
+            .error_first = error_first,
+        };
+    }
+
     /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
     /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
     /// points. The test executor will then be able to invoke these to run the tests.
@@ -1585,6 +1608,7 @@ pub const DeclGen = struct {
             .loop       => return self.airLoop(inst),
             .ret        => return self.airRet(inst),
             .ret_load   => return self.airRetLoad(inst),
+            .@"try"     => try self.airTry(inst),
             .switch_br  => return self.airSwitchBr(inst),
             .unreach    => return self.airUnreach(),
 
@@ -1752,16 +1776,15 @@ pub const DeclGen = struct {
         const operand_ty_id = try self.resolveTypeId(operand_ty);
         const result_type_id = try self.resolveTypeId(result_ty);
 
-        const overflow_member_ty = try self.intType(.unsigned, info.bits);
-        const overflow_member_ty_id = self.typeId(overflow_member_ty);
+        const overflow_member_ty_ref = try self.intType(.unsigned, info.bits);
 
         const op_result_id = blk: {
             // Construct the SPIR-V result type.
             // It is almost the same as the zig one, except that the fields must be the same type
             // and they must be unsigned.
             const overflow_result_ty_ref = try self.spv.simpleStructType(&.{
-                .{ .ty = overflow_member_ty, .name = "res" },
-                .{ .ty = overflow_member_ty, .name = "ov" },
+                .{ .ty = overflow_member_ty_ref, .name = "res" },
+                .{ .ty = overflow_member_ty_ref, .name = "ov" },
             });
             const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpIAddCarry, .{
@@ -1775,8 +1798,8 @@ pub const DeclGen = struct {
 
         // Now convert the SPIR-V flavor result into a Zig-flavor result.
         // First, extract the two fields.
-        const unsigned_result = try self.extractField(overflow_member_ty_id, op_result_id, 0);
-        const overflow = try self.extractField(overflow_member_ty_id, op_result_id, 1);
+        const unsigned_result = try self.extractField(overflow_member_ty_ref, op_result_id, 0);
+        const overflow = try self.extractField(overflow_member_ty_ref, op_result_id, 1);
 
         // We need to convert the results to the types that Zig expects here.
         // The `result` is the same type except unsigned, so we can just bitcast that.
@@ -1954,15 +1977,16 @@ pub const DeclGen = struct {
         return result_id;
     }
 
-    fn extractField(self: *DeclGen, result_ty: IdResultType, object: IdRef, field: u32) !IdRef {
+    fn extractField(self: *DeclGen, result_ty_ref: SpvType.Ref, object: IdRef, field: u32) !IdRef {
         const result_id = self.spv.allocId();
         const indexes = [_]u32{field};
         try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
-            .id_result_type = result_ty,
+            .id_result_type = self.typeId(result_ty_ref),
             .id_result = result_id,
             .composite = object,
             .indexes = &indexes,
         });
+        // TODO: Convert bools, direct structs should have their field types as indirect values.
         return result_id;
     }
 
@@ -1970,7 +1994,7 @@ pub const DeclGen = struct {
         if (self.liveness.isUnused(inst)) return null;
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         return try self.extractField(
-            try self.resolveTypeId(self.air.typeOfIndex(inst)),
+            try self.resolveType(self.air.typeOfIndex(inst), .direct),
             try self.resolve(ty_op.operand),
             field,
         );
@@ -2451,6 +2475,66 @@ pub const DeclGen = struct {
         });
     }
 
+    fn airTry(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const err_union_id = try self.resolve(pl_op.operand);
+        const extra = self.air.extraData(Air.Try, pl_op.payload);
+        const body = self.air.extra[extra.end..][0..extra.data.body_len];
+
+        const err_union_ty = self.air.typeOf(pl_op.operand);
+        const payload_ty = self.air.typeOfIndex(inst);
+
+        const err_ty_ref = try self.resolveType(Type.anyerror, .direct);
+        const payload_ty_ref = try self.resolveType(payload_ty, .direct);
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+
+        const eu_layout = self.errorUnionLayout(payload_ty);
+
+        if (!err_union_ty.errorUnionSet().errorSetIsEmpty()) {
+            const err_id = if (eu_layout.payload_has_bits)
+                try self.extractField(err_ty_ref, err_union_id, eu_layout.errorFieldIndex())
+            else
+                err_union_id;
+
+            const zero_id = try self.constInt(err_ty_ref, 0);
+            const is_err_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                .id_result_type = self.typeId(bool_ty_ref),
+                .id_result = is_err_id,
+                .operand_1 = err_id,
+                .operand_2 = zero_id,
+            });
+
+            // When there is an error, we must evaluate `body`. Otherwise we must continue
+            // with the current body.
+            // Just generate a new block here, then generate a new block inline for the remainder of the body.
+
+            const err_block = self.spv.allocId();
+            const ok_block = self.spv.allocId();
+
+            // TODO: Merge block
+            try self.func.body.emit(self.spv.gpa, .OpBranchConditional, .{
+                .condition = is_err_id,
+                .true_label = err_block,
+                .false_label = ok_block,
+            });
+
+            try self.beginSpvBlock(err_block);
+            try self.genBody(body);
+
+            try self.beginSpvBlock(ok_block);
+            // Now just extract the payload, if required.
+        }
+        if (self.liveness.isUnused(inst)) {
+            return null;
+        }
+        if (!eu_layout.payload_has_bits) {
+            return null;
+        }
+
+        return try self.extractField(payload_ty_ref, err_union_id, eu_layout.payloadFieldIndex());
+    }
+
     fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
         const target = self.getTarget();
         const pl_op = self.air.instructions.items(.data)[inst].pl_op;