Commit 1caf56c5fb

joachimschmidt557 <joachim.schmidt557@outlook.com>
2022-12-29 05:21:31
stage2 AArch64: implement errUnion{Err,Payload} for registers
1 parent 34887cf
Changed files (6)
src/arch/aarch64/CodeGen.zig
@@ -3050,19 +3050,60 @@ fn airOptionalPayloadPtrSet(self: *Self, inst: Air.Inst.Index) !void {
 }
 
 /// Given an error union, returns the error
-fn errUnionErr(self: *Self, error_union_mcv: MCValue, error_union_ty: Type) !MCValue {
+fn errUnionErr(
+    self: *Self,
+    error_union_bind: ReadArg.Bind,
+    error_union_ty: Type,
+    maybe_inst: ?Air.Inst.Index,
+) !MCValue {
     const err_ty = error_union_ty.errorUnionSet();
     const payload_ty = error_union_ty.errorUnionPayload();
     if (err_ty.errorSetIsEmpty()) {
         return MCValue{ .immediate = 0 };
     }
     if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
-        return error_union_mcv;
+        return try error_union_bind.resolveToMcv(self);
     }
 
     const err_offset = @intCast(u32, errUnionErrorOffset(payload_ty, self.target.*));
-    switch (error_union_mcv) {
-        .register => return self.fail("TODO errUnionErr for registers", .{}),
+    switch (try error_union_bind.resolveToMcv(self)) {
+        .register => {
+            var operand_reg: Register = undefined;
+            var dest_reg: Register = undefined;
+
+            const read_args = [_]ReadArg{
+                .{ .ty = error_union_ty, .bind = error_union_bind, .class = gp, .reg = &operand_reg },
+            };
+            const write_args = [_]WriteArg{
+                .{ .ty = err_ty, .bind = .none, .class = gp, .reg = &dest_reg },
+            };
+            try self.allocRegs(
+                &read_args,
+                &write_args,
+                if (maybe_inst) |inst| .{
+                    .corresponding_inst = inst,
+                    .operand_mapping = &.{0},
+                } else null,
+            );
+
+            const err_bit_offset = err_offset * 8;
+            const err_bit_size = @intCast(u32, err_ty.abiSize(self.target.*)) * 8;
+
+            _ = try self.addInst(.{
+                .tag = .ubfx, // errors are unsigned integers
+                .data = .{
+                    .rr_lsb_width = .{
+                        // Set both registers to the X variant to get the full width
+                        .rd = dest_reg.toX(),
+                        .rn = operand_reg.toX(),
+                        .lsb = @intCast(u6, err_bit_offset),
+                        .width = @intCast(u7, err_bit_size),
+                    },
+                },
+            });
+
+            return MCValue{ .register = dest_reg };
+        },
         .stack_argument_offset => |off| {
             return MCValue{ .stack_argument_offset = off + err_offset };
         },
@@ -3079,27 +3120,69 @@ fn errUnionErr(self: *Self, error_union_mcv: MCValue, error_union_ty: Type) !MCV
 fn airUnwrapErrErr(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
+        const error_union_bind: ReadArg.Bind = .{ .inst = ty_op.operand };
         const error_union_ty = self.air.typeOf(ty_op.operand);
-        const mcv = try self.resolveInst(ty_op.operand);
-        break :result try self.errUnionErr(mcv, error_union_ty);
+
+        break :result try self.errUnionErr(error_union_bind, error_union_ty, inst);
     };
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
 /// Given an error union, returns the payload
-fn errUnionPayload(self: *Self, error_union_mcv: MCValue, error_union_ty: Type) !MCValue {
+fn errUnionPayload(
+    self: *Self,
+    error_union_bind: ReadArg.Bind,
+    error_union_ty: Type,
+    maybe_inst: ?Air.Inst.Index,
+) !MCValue {
     const err_ty = error_union_ty.errorUnionSet();
     const payload_ty = error_union_ty.errorUnionPayload();
     if (err_ty.errorSetIsEmpty()) {
-        return error_union_mcv;
+        return try error_union_bind.resolveToMcv(self);
     }
     if (!payload_ty.hasRuntimeBitsIgnoreComptime()) {
         return MCValue.none;
     }
 
     const payload_offset = @intCast(u32, errUnionPayloadOffset(payload_ty, self.target.*));
-    switch (error_union_mcv) {
-        .register => return self.fail("TODO errUnionPayload for registers", .{}),
+    switch (try error_union_bind.resolveToMcv(self)) {
+        .register => {
+            var operand_reg: Register = undefined;
+            var dest_reg: Register = undefined;
+
+            const read_args = [_]ReadArg{
+                .{ .ty = error_union_ty, .bind = error_union_bind, .class = gp, .reg = &operand_reg },
+            };
+            const write_args = [_]WriteArg{
+                .{ .ty = err_ty, .bind = .none, .class = gp, .reg = &dest_reg },
+            };
+            try self.allocRegs(
+                &read_args,
+                &write_args,
+                if (maybe_inst) |inst| .{
+                    .corresponding_inst = inst,
+                    .operand_mapping = &.{0},
+                } else null,
+            );
+
+            const payload_bit_offset = payload_offset * 8;
+            const payload_bit_size = @intCast(u32, payload_ty.abiSize(self.target.*)) * 8;
+
+            _ = try self.addInst(.{
+                .tag = if (payload_ty.isSignedInt()) Mir.Inst.Tag.sbfx else .ubfx,
+                .data = .{
+                    .rr_lsb_width = .{
+                        // Set both registers to the X variant to get the full width
+                        .rd = dest_reg.toX(),
+                        .rn = operand_reg.toX(),
+                        .lsb = @intCast(u5, payload_bit_offset),
+                        .width = @intCast(u6, payload_bit_size),
+                    },
+                },
+            });
+
+            return MCValue{ .register = dest_reg };
+        },
         .stack_argument_offset => |off| {
             return MCValue{ .stack_argument_offset = off + payload_offset };
         },
@@ -3116,9 +3199,10 @@ fn errUnionPayload(self: *Self, error_union_mcv: MCValue, error_union_ty: Type)
 fn airUnwrapErrPayload(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
+        const error_union_bind: ReadArg.Bind = .{ .inst = ty_op.operand };
         const error_union_ty = self.air.typeOf(ty_op.operand);
-        const error_union = try self.resolveInst(ty_op.operand);
-        break :result try self.errUnionPayload(error_union, error_union_ty);
+
+        break :result try self.errUnionPayload(error_union_bind, error_union_ty, inst);
     };
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
@@ -3399,9 +3483,14 @@ fn airArrayElemVal(self: *Self, inst: Air.Inst.Index) !void {
 }
 
 fn airPtrElemVal(self: *Self, inst: Air.Inst.Index) !void {
-    const is_volatile = false; // TODO
     const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-    const result: MCValue = if (!is_volatile and self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement ptr_elem_val for {}", .{self.target.cpu.arch});
+    const ptr_ty = self.air.typeOf(bin_op.lhs);
+    const result: MCValue = if (!ptr_ty.isVolatilePtr() and self.liveness.isUnused(inst)) .dead else result: {
+        const base_bind: ReadArg.Bind = .{ .inst = bin_op.lhs };
+        const index_bind: ReadArg.Bind = .{ .inst = bin_op.rhs };
+
+        break :result try self.ptrElemVal(base_bind, index_bind, ptr_ty, inst);
+    };
     return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
 }
 
@@ -4792,19 +4881,27 @@ fn isNonNull(self: *Self, operand_bind: ReadArg.Bind, operand_ty: Type) !MCValue
     return MCValue{ .compare_flags = is_null_res.compare_flags.negate() };
 }
 
-fn isErr(self: *Self, ty: Type, operand: MCValue) !MCValue {
-    const error_type = ty.errorUnionSet();
+fn isErr(
+    self: *Self,
+    error_union_bind: ReadArg.Bind,
+    error_union_ty: Type,
+) !MCValue {
+    const error_type = error_union_ty.errorUnionSet();
 
     if (error_type.errorSetIsEmpty()) {
         return MCValue{ .immediate = 0 }; // always false
     }
 
-    const error_mcv = try self.errUnionErr(operand, ty);
+    const error_mcv = try self.errUnionErr(error_union_bind, error_union_ty, null);
     return try self.cmp(.{ .mcv = error_mcv }, .{ .mcv = .{ .immediate = 0 } }, error_type, .gt);
 }
 
-fn isNonErr(self: *Self, ty: Type, operand: MCValue) !MCValue {
-    const is_err_result = try self.isErr(ty, operand);
+fn isNonErr(
+    self: *Self,
+    error_union_bind: ReadArg.Bind,
+    error_union_ty: Type,
+) !MCValue {
+    const is_err_result = try self.isErr(error_union_bind, error_union_ty);
     switch (is_err_result) {
         .compare_flags => |cond| {
             assert(cond == .hi);
@@ -4873,9 +4970,10 @@ fn airIsNonNullPtr(self: *Self, inst: Air.Inst.Index) !void {
 fn airIsErr(self: *Self, inst: Air.Inst.Index) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
-        const operand = try self.resolveInst(un_op);
-        const ty = self.air.typeOf(un_op);
-        break :result try self.isErr(ty, operand);
+        const error_union_bind: ReadArg.Bind = .{ .inst = un_op };
+        const error_union_ty = self.air.typeOf(un_op);
+
+        break :result try self.isErr(error_union_bind, error_union_ty);
     };
     return self.finishAir(inst, result, .{ un_op, .none, .none });
 }
@@ -4890,7 +4988,7 @@ fn airIsErrPtr(self: *Self, inst: Air.Inst.Index) !void {
         const operand = try self.allocRegOrMem(elem_ty, true, null);
         try self.load(operand, operand_ptr, ptr_ty);
 
-        break :result try self.isErr(elem_ty, operand);
+        break :result try self.isErr(.{ .mcv = operand }, elem_ty);
     };
     return self.finishAir(inst, result, .{ un_op, .none, .none });
 }
@@ -4898,9 +4996,10 @@ fn airIsErrPtr(self: *Self, inst: Air.Inst.Index) !void {
 fn airIsNonErr(self: *Self, inst: Air.Inst.Index) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
-        const operand = try self.resolveInst(un_op);
-        const ty = self.air.typeOf(un_op);
-        break :result try self.isNonErr(ty, operand);
+        const error_union_bind: ReadArg.Bind = .{ .inst = un_op };
+        const error_union_ty = self.air.typeOf(un_op);
+
+        break :result try self.isNonErr(error_union_bind, error_union_ty);
     };
     return self.finishAir(inst, result, .{ un_op, .none, .none });
 }
@@ -4915,7 +5014,7 @@ fn airIsNonErrPtr(self: *Self, inst: Air.Inst.Index) !void {
         const operand = try self.allocRegOrMem(elem_ty, true, null);
         try self.load(operand, operand_ptr, ptr_ty);
 
-        break :result try self.isNonErr(elem_ty, operand);
+        break :result try self.isNonErr(.{ .mcv = operand }, elem_ty);
     };
     return self.finishAir(inst, result, .{ un_op, .none, .none });
 }
@@ -5960,15 +6059,24 @@ fn airTry(self: *Self, inst: Air.Inst.Index) !void {
     const extra = self.air.extraData(Air.Try, pl_op.payload);
     const body = self.air.extra[extra.end..][0..extra.data.body_len];
     const result: MCValue = result: {
+        const error_union_bind: ReadArg.Bind = .{ .inst = pl_op.operand };
         const error_union_ty = self.air.typeOf(pl_op.operand);
-        const error_union = try self.resolveInst(pl_op.operand);
-        const is_err_result = try self.isErr(error_union_ty, error_union);
+        const error_union_size = @intCast(u32, error_union_ty.abiSize(self.target.*));
+        const error_union_align = error_union_ty.abiAlignment(self.target.*);
+
+        // The error union will die in the body. However, we need the
+        // error union after the body in order to extract the payload
+        // of the error union, so we create a copy of it
+        const error_union_copy = try self.allocMem(error_union_size, error_union_align, null);
+        try self.genSetStack(error_union_ty, error_union_copy, try error_union_bind.resolveToMcv(self));
+
+        const is_err_result = try self.isErr(error_union_bind, error_union_ty);
         const reloc = try self.condBr(is_err_result);
 
         try self.genBody(body);
-
         try self.performReloc(reloc);
-        break :result try self.errUnionPayload(error_union, error_union_ty);
+
+        break :result try self.errUnionPayload(.{ .mcv = .{ .stack_offset = error_union_copy } }, error_union_ty, null);
     };
     return self.finishAir(inst, result, .{ pl_op.operand, .none, .none });
 }
test/behavior/error.zig
@@ -60,7 +60,6 @@ pub fn baz() anyerror!i32 {
 }
 
 test "error wrapping" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     try expect((baz() catch unreachable) == 15);
@@ -100,7 +99,6 @@ test "syntax: optional operator in front of error union operator" {
 
 test "widen cast integer payload of error union function call" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     const S = struct {
@@ -715,7 +713,6 @@ test "ret_ptr doesn't cause own inferred error set to be resolved" {
 }
 
 test "simple else prong allowed even when all errors handled" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
test/behavior/if.zig
@@ -44,7 +44,6 @@ var global_with_val: anyerror!u32 = 0;
 var global_with_err: anyerror!u32 = error.SomeError;
 
 test "unwrap mutable global var" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     if (global_with_val) |v| {
test/behavior/switch.zig
@@ -390,7 +390,6 @@ fn switchWithUnreachable(x: i32) i32 {
 }
 
 test "capture value of switch with all unreachable prongs" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     const x = return_a_number() catch |err| switch (err) {
@@ -494,7 +493,6 @@ test "switch prongs with error set cases make a new error set type for capture v
 }
 
 test "return result loc and then switch with range implicit casted to error union" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     const S = struct {
test/behavior/try.zig
@@ -3,7 +3,6 @@ const builtin = @import("builtin");
 const expect = std.testing.expect;
 
 test "try on error union" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     try tryOnErrorUnionImpl();
test/behavior/while.zig
@@ -175,7 +175,6 @@ test "while with optional as condition with else" {
 
 test "while with error union condition" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     numbers_left = 10;
@@ -258,7 +257,6 @@ fn returnWithImplicitCastFromWhileLoopTest() anyerror!void {
 }
 
 test "while on error union with else result follow else prong" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     const result = while (returnError()) |value| {
@@ -268,7 +266,6 @@ test "while on error union with else result follow else prong" {
 }
 
 test "while on error union with else result follow break prong" {
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     const result = while (returnSuccess(10)) |value| {