Commit 004d0c8978

David Rubin <daviru007@icloud.com>
2024-04-19 21:40:24
riscv: switch progress + by-ref return progress
1 parent 4aa1544
src/arch/riscv64/CodeGen.zig
@@ -1223,7 +1223,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .field_parent_ptr => try self.airFieldParentPtr(inst),
 
-            .switch_br       => try self.airSwitch(inst),
+            .switch_br       => try self.airSwitchBr(inst),
             .slice_ptr       => try self.airSlicePtr(inst),
             .slice_len       => try self.airSliceLen(inst),
 
@@ -1960,7 +1960,7 @@ fn binOp(
             switch (lhs_ty.zigTypeTag(zcu)) {
                 .Float => return self.fail("TODO binary operations on floats", .{}),
                 .Vector => return self.fail("TODO binary operations on vectors", .{}),
-                .Int => {
+                .Int, .Enum => {
                     assert(lhs_ty.eql(rhs_ty, zcu));
                     const int_info = lhs_ty.intInfo(zcu);
                     if (int_info.bits <= 64) {
@@ -3682,7 +3682,6 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) !void {
     switch (self.ret_mcv.short) {
         .none => {},
         .register, .register_pair => try self.load(self.ret_mcv.short, ptr, ptr_ty),
-        .indirect => |reg_off| try self.genSetReg(ptr_ty, reg_off.reg, ptr),
         else => unreachable,
     }
     self.ret_mcv.liveOut(self, inst);
@@ -4160,12 +4159,97 @@ fn lowerBlock(self: *Self, inst: Air.Inst.Index, body: []const Air.Inst.Index) !
     self.finishAirBookkeeping();
 }
 
-fn airSwitch(self: *Self, inst: Air.Inst.Index) !void {
+fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void {
     const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
-    const condition = pl_op.operand;
-    _ = condition;
-    return self.fail("TODO airSwitch for {}", .{self.target.cpu.arch});
-    // return self.finishAir(inst, .dead, .{ condition, .none, .none });
+    const condition = try self.resolveInst(pl_op.operand);
+    const condition_ty = self.typeOf(pl_op.operand);
+    const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload);
+    var extra_index: usize = switch_br.end;
+    var case_i: u32 = 0;
+    const liveness = try self.liveness.getSwitchBr(self.gpa, inst, switch_br.data.cases_len + 1);
+    defer self.gpa.free(liveness.deaths);
+
+    // If the condition dies here in this switch instruction, process
+    // that death now instead of later as this has an effect on
+    // whether it needs to be spilled in the branches
+    if (self.liveness.operandDies(inst, 0)) {
+        if (pl_op.operand.toIndex()) |op_inst| try self.processDeath(op_inst);
+    }
+
+    self.scope_generation += 1;
+    const state = try self.saveState();
+
+    while (case_i < switch_br.data.cases_len) : (case_i += 1) {
+        const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
+        const items: []const Air.Inst.Ref =
+            @ptrCast(self.air.extra[case.end..][0..case.data.items_len]);
+        const case_body: []const Air.Inst.Index =
+            @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]);
+        extra_index = case.end + items.len + case_body.len;
+
+        var relocs = try self.gpa.alloc(Mir.Inst.Index, items.len);
+        defer self.gpa.free(relocs);
+
+        for (items, relocs, 0..) |item, *reloc, i| {
+            // switch branches must be comptime-known, so this is stored in an immediate
+            const item_mcv = try self.resolveInst(item);
+
+            const cmp_mcv: MCValue = try self.binOp(
+                .cmp_neq,
+                condition,
+                condition_ty,
+                item_mcv,
+                condition_ty,
+            );
+
+            const cmp_reg = try self.copyToTmpRegister(Type.bool, cmp_mcv);
+
+            if (!(i < relocs.len - 1)) {
+                _ = try self.addInst(.{
+                    .tag = .pseudo,
+                    .ops = .pseudo_not,
+                    .data = .{ .rr = .{
+                        .rd = cmp_reg,
+                        .rs = cmp_reg,
+                    } },
+                });
+            }
+
+            reloc.* = try self.condBr(condition_ty, .{ .register = cmp_reg });
+        }
+
+        for (liveness.deaths[case_i]) |operand| try self.processDeath(operand);
+
+        for (relocs[0 .. relocs.len - 1]) |reloc| self.performReloc(reloc);
+        try self.genBody(case_body);
+        try self.restoreState(state, &.{}, .{
+            .emit_instructions = false,
+            .update_tracking = true,
+            .resurrect = true,
+            .close_scope = true,
+        });
+
+        self.performReloc(relocs[relocs.len - 1]);
+    }
+
+    if (switch_br.data.else_body_len > 0) {
+        const else_body: []const Air.Inst.Index =
+            @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]);
+
+        const else_deaths = liveness.deaths.len - 1;
+        for (liveness.deaths[else_deaths]) |operand| try self.processDeath(operand);
+
+        try self.genBody(else_body);
+        try self.restoreState(state, &.{}, .{
+            .emit_instructions = false,
+            .update_tracking = true,
+            .resurrect = true,
+            .close_scope = true,
+        });
+    }
+
+    // We already took care of pl_op.operand earlier, so there's nothing left to do
+    self.finishAirBookkeeping();
 }
 
 fn performReloc(self: *Self, inst: Mir.Inst.Index) void {
@@ -4249,9 +4333,60 @@ fn airBr(self: *Self, inst: Air.Inst.Index) !void {
 
 fn airBoolOp(self: *Self, inst: Air.Inst.Index) !void {
     const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
-    const air_tags = self.air.instructions.items(.tag);
-    _ = air_tags;
-    const result: MCValue = if (self.liveness.isUnused(inst)) .unreach else return self.fail("TODO implement boolean operations for {}", .{self.target.cpu.arch});
+    const tag: Air.Inst.Tag = self.air.instructions.items(.tag)[@intFromEnum(inst)];
+
+    const result: MCValue = if (self.liveness.isUnused(inst)) .unreach else result: {
+        const lhs = try self.resolveInst(bin_op.lhs);
+        const rhs = try self.resolveInst(bin_op.rhs);
+        const lhs_ty = Type.bool;
+        const rhs_ty = Type.bool;
+
+        const lhs_reg, const lhs_lock = blk: {
+            if (lhs == .register) break :blk .{ lhs.register, null };
+
+            const lhs_reg, const lhs_lock = try self.allocReg();
+            try self.genSetReg(lhs_ty, lhs_reg, lhs);
+            break :blk .{ lhs_reg, lhs_lock };
+        };
+        defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+        const rhs_reg, const rhs_lock = blk: {
+            if (rhs == .register) break :blk .{ rhs.register, null };
+
+            const rhs_reg, const rhs_lock = try self.allocReg();
+            try self.genSetReg(rhs_ty, rhs_reg, rhs);
+            break :blk .{ rhs_reg, rhs_lock };
+        };
+        defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+        const result_reg, const result_lock = try self.allocReg();
+        defer self.register_manager.unlockReg(result_lock);
+
+        _ = try self.addInst(.{
+            .tag = if (tag == .bool_or) .@"or" else .@"and",
+            .ops = .rrr,
+            .data = .{ .r_type = .{
+                .rd = result_reg,
+                .rs1 = lhs_reg,
+                .rs2 = rhs_reg,
+            } },
+        });
+
+        // safety truncate
+        if (self.wantSafety()) {
+            _ = try self.addInst(.{
+                .tag = .andi,
+                .ops = .rri,
+                .data = .{ .i_type = .{
+                    .rd = result_reg,
+                    .rs1 = result_reg,
+                    .imm12 = Immediate.s(1),
+                } },
+            });
+        }
+
+        break :result .{ .register = result_reg };
+    };
     return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
 }
 
@@ -5265,7 +5400,9 @@ fn resolveCallingConventionValues(
                     },
                     .memory => {
                         const param_int_regs = abi.function_arg_regs;
+
                         const param_int_reg = param_int_regs[param_int_reg_i];
+                        param_int_reg_i += 1;
 
                         arg_mcv[arg_mcv_i] = .{ .indirect = .{ .reg = param_int_reg } };
                         arg_mcv_i += 1;
src/arch/riscv64/Encoding.zig
@@ -38,6 +38,7 @@ pub const Mnemonic = enum {
     // R Type
     add,
     @"and",
+    @"or",
     sub,
     slt,
     mul,
@@ -55,6 +56,7 @@ pub const Mnemonic = enum {
             .add    => .{ .opcode = 0b0110011, .funct3 = 0b000, .funct7 = 0b0000000 },
             .sltu   => .{ .opcode = 0b0110011, .funct3 = 0b011, .funct7 = 0b0000000 },
             .@"and" => .{ .opcode = 0b0110011, .funct3 = 0b111, .funct7 = 0b0000000 },
+            .@"or"  => .{ .opcode = 0b0110011, .funct3 = 0b110, .funct7 = 0b0000000 },
             .sub    => .{ .opcode = 0b0110011, .funct3 = 0b000, .funct7 = 0b0100000 }, 
 
             .ld     => .{ .opcode = 0b0000011, .funct3 = 0b011, .funct7 = null      },
@@ -152,6 +154,7 @@ pub const InstEnc = enum {
             .add,
             .sub,
             .@"and",
+            .@"or",
             => .R,
 
             .ecall,
src/arch/riscv64/Mir.zig
@@ -80,9 +80,6 @@ pub const Inst = struct {
         /// Branch if not equal, Uses b_type
         bne,
 
-        /// Boolean NOT, Uses rr payload
-        not,
-
         /// Generates a NO-OP, uses nop payload
         nop,
 
test/behavior/align.zig
@@ -624,7 +624,6 @@ test "alignment of slice element" {
 }
 
 test "sub-aligned pointer field access" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest;
test/behavior/cast.zig
@@ -881,7 +881,6 @@ test "peer resolution of string literals" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         const E = enum { a, b, c, d };
test/behavior/enum.zig
@@ -610,7 +610,6 @@ fn testEnumWithSpecifiedTagValues(x: MultipleChoice) !void {
 test "enum with specified tag values" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try testEnumWithSpecifiedTagValues(MultipleChoice.C);
     try comptime testEnumWithSpecifiedTagValues(MultipleChoice.C);
@@ -749,7 +748,6 @@ test "cast integer literal to enum" {
 test "enum with specified and unspecified tag values" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try testEnumWithSpecifiedAndUnspecifiedTagValues(MultipleChoice2.D);
     try comptime testEnumWithSpecifiedAndUnspecifiedTagValues(MultipleChoice2.D);
test/behavior/eval.zig
@@ -1088,7 +1088,6 @@ test "comptime break operand passing through runtime condition converted to runt
 test "comptime break operand passing through runtime switch converted to runtime break" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest(runtime: u8) !void {
@@ -1631,8 +1630,6 @@ test "struct in comptime false branch is not evaluated" {
 }
 
 test "result of nested switch assigned to variable" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     var zds: u32 = 0;
     zds = switch (zds) {
         0 => switch (zds) {
@@ -1667,8 +1664,6 @@ test "inline for loop of functions returning error unions" {
 }
 
 test "if inside a switch" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     var condition = true;
     var wave_type: u32 = 0;
     _ = .{ &condition, &wave_type };
test/behavior/inline_switch.zig
@@ -5,7 +5,6 @@ const builtin = @import("builtin");
 test "inline scalar prongs" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var x: usize = 0;
     switch (x) {
@@ -21,7 +20,6 @@ test "inline scalar prongs" {
 test "inline prong ranges" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var x: usize = 0;
     _ = &x;
@@ -37,7 +35,6 @@ const E = enum { a, b, c, d };
 test "inline switch enums" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var x: E = .a;
     _ = &x;
@@ -106,7 +103,6 @@ test "inline else error" {
 test "inline else enum" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const E2 = enum(u8) { a = 2, b = 3, c = 4, d = 5 };
     var a: E2 = .a;
@@ -120,7 +116,6 @@ test "inline else enum" {
 test "inline else int with gaps" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var a: u8 = 0;
     _ = &a;
@@ -139,7 +134,6 @@ test "inline else int with gaps" {
 test "inline else int all values" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var a: u2 = 0;
     _ = &a;
test/behavior/ref_var_in_if_after_if_2nd_switch_prong.zig
@@ -8,7 +8,6 @@ test "reference a variable in an if after an if in the 2nd switch prong" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try foo(true, Num.Two, false, "aoeu");
     try expect(!ok);
test/behavior/switch.zig
@@ -7,7 +7,6 @@ const expectEqual = std.testing.expectEqual;
 
 test "switch with numbers" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try testSwitchWithNumbers(13);
 }
@@ -23,7 +22,6 @@ fn testSwitchWithNumbers(x: u32) !void {
 
 test "switch with all ranges" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try expect(testSwitchWithAllRanges(50, 3) == 1);
     try expect(testSwitchWithAllRanges(101, 0) == 2);
@@ -57,27 +55,25 @@ test "implicit comptime switch" {
 
 test "switch on enum" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const fruit = Fruit.Orange;
-    nonConstSwitchOnEnum(fruit);
+    try expect(nonConstSwitchOnEnum(fruit));
 }
 const Fruit = enum {
     Apple,
     Orange,
     Banana,
 };
-fn nonConstSwitchOnEnum(fruit: Fruit) void {
-    switch (fruit) {
-        Fruit.Apple => unreachable,
-        Fruit.Orange => {},
-        Fruit.Banana => unreachable,
-    }
+fn nonConstSwitchOnEnum(fruit: Fruit) bool {
+    return switch (fruit) {
+        Fruit.Apple => false,
+        Fruit.Orange => true,
+        Fruit.Banana => false,
+    };
 }
 
 test "switch statement" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try nonConstSwitch(SwitchStatementFoo.C);
 }
@@ -94,7 +90,6 @@ const SwitchStatementFoo = enum { A, B, C, D };
 
 test "switch with multiple expressions" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const x = switch (returnsFive()) {
         1, 2, 3 => 1,
@@ -179,7 +174,6 @@ test "undefined.u0" {
 
 test "switch with disjoint range" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var q: u8 = 0;
     _ = &q;
@@ -191,8 +185,6 @@ test "switch with disjoint range" {
 }
 
 test "switch variable for range and multiple prongs" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     const S = struct {
         fn doTheTest() !void {
             try doTheSwitch(16);
@@ -382,7 +374,6 @@ test "anon enum literal used in switch on union enum" {
 
 test "switch all prongs unreachable" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try testAllProngsUnreachable();
     try comptime testAllProngsUnreachable();
@@ -420,7 +411,6 @@ fn return_a_number() anyerror!i32 {
 
 test "switch on integer with else capturing expr" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -735,7 +725,6 @@ test "switch capture copies its payload" {
 
 test "capture of integer forwards the switch condition directly" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
         fn foo(x: u8) !void {
@@ -757,7 +746,6 @@ test "capture of integer forwards the switch condition directly" {
 
 test "enum value without tag name used as switch item" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const E = enum(u32) {
         a = 1,
@@ -775,8 +763,6 @@ test "enum value without tag name used as switch item" {
 }
 
 test "switch item sizeof" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     const S = struct {
         fn doTheTest() !void {
             var a: usize = 0;
@@ -873,8 +859,6 @@ test "switch pointer capture peer type resolution" {
 }
 
 test "inline switch range that includes the maximum value of the switched type" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     const inputs: [3]u8 = .{ 0, 254, 255 };
     for (inputs) |input| {
         switch (input) {
@@ -970,8 +954,6 @@ test "prong with inline call to unreachable" {
 }
 
 test "block error return trace index is reset between prongs" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
     const S = struct {
         fn returnError() error{TestFailed} {
             return error.TestFailed;