Commit 846bd40361

David Rubin <daviru007@icloud.com>
2024-07-26 21:43:47
riscv: implement `@cmpxchg*` and remove fixes
1 parent 9752bbf
Changed files (3)
src
arch
test
behavior
src/arch/riscv64/CodeGen.zig
@@ -1287,13 +1287,11 @@ fn gen(func: *Func) !void {
         // ret
         _ = try func.addInst(.{
             .tag = .jalr,
-            .data = .{
-                .i_type = .{
-                    .rd = .zero,
-                    .rs1 = .ra,
-                    .imm12 = Immediate.s(0),
-                },
-            },
+            .data = .{ .i_type = .{
+                .rd = .zero,
+                .rs1 = .ra,
+                .imm12 = Immediate.s(0),
+            } },
         });
 
         const frame_layout = try func.computeFrameLayout();
@@ -1472,14 +1470,11 @@ fn genLazy(func: *Func, lazy_sym: link.File.LazySymbol) InnerError!void {
 
             _ = try func.addInst(.{
                 .tag = .jalr,
-
-                .data = .{
-                    .i_type = .{
-                        .rd = .zero,
-                        .rs1 = .ra,
-                        .imm12 = Immediate.s(0),
-                    },
-                },
+                .data = .{ .i_type = .{
+                    .rd = .zero,
+                    .rs1 = .ra,
+                    .imm12 = Immediate.s(0),
+                } },
             });
         },
         else => return func.fail(
@@ -1629,8 +1624,8 @@ fn genBody(func: *Func, body: []const Air.Inst.Index) InnerError!void {
             .struct_field_val=> try func.airStructFieldVal(inst),
             .float_from_int  => try func.airFloatFromInt(inst),
             .int_from_float  => try func.airIntFromFloat(inst),
-            .cmpxchg_strong  => try func.airCmpxchg(inst),
-            .cmpxchg_weak    => try func.airCmpxchg(inst),
+            .cmpxchg_strong  => try func.airCmpxchg(inst, .strong),
+            .cmpxchg_weak    => try func.airCmpxchg(inst, .weak),
             .atomic_rmw      => try func.airAtomicRmw(inst),
             .atomic_load     => try func.airAtomicLoad(inst),
             .memcpy          => try func.airMemcpy(inst),
@@ -3527,7 +3522,8 @@ fn airWrapOptional(func: *Func, inst: Air.Inst.Index) !void {
         };
         defer if (pl_lock) |lock| func.register_manager.unlockReg(lock);
 
-        const opt_mcv = try func.allocRegOrMem(opt_ty, inst, true);
+        const opt_mcv = try func.allocRegOrMem(opt_ty, inst, false);
+        try func.genCopy(pl_ty, opt_mcv, pl_mcv);
 
         if (!same_repr) {
             const pl_abi_size: i32 = @intCast(pl_ty.abiSize(pt));
@@ -3541,18 +3537,6 @@ fn airWrapOptional(func: *Func, inst: Air.Inst.Index) !void {
                         .{ .immediate = 1 },
                     );
                 },
-
-                .register => |opt_reg| {
-                    try func.genBinOp(
-                        .shl,
-                        .{ .immediate = 1 },
-                        Type.u64,
-                        .{ .immediate = 32 },
-                        Type.u64,
-                        opt_reg,
-                    );
-                    try func.genCopy(pl_ty, opt_mcv, pl_mcv);
-                },
                 else => unreachable,
             }
         }
@@ -5800,6 +5784,7 @@ fn performReloc(func: *Func, inst: Mir.Inst.Index) void {
 
     switch (tag) {
         .beq,
+        .bne,
         => func.mir_instructions.items(.data)[inst].b_type.inst = target,
         .jal => func.mir_instructions.items(.data)[inst].j_type.inst = target,
         .pseudo_j => func.mir_instructions.items(.data)[inst].j_type.inst = target,
@@ -6047,30 +6032,85 @@ fn airAsm(func: *Func, inst: Air.Inst.Index) !void {
         }
     }
 
+    const Label = struct {
+        target: Mir.Inst.Index = undefined,
+        pending_relocs: std.ArrayListUnmanaged(Mir.Inst.Index) = .{},
+
+        const Kind = enum { definition, reference };
+
+        fn isValid(kind: Kind, name: []const u8) bool {
+            for (name, 0..) |c, i| switch (c) {
+                else => return false,
+                '$' => if (i == 0) return false,
+                '.' => {},
+                '0'...'9' => if (i == 0) switch (kind) {
+                    .definition => if (name.len != 1) return false,
+                    .reference => {
+                        if (name.len != 2) return false;
+                        switch (name[1]) {
+                            else => return false,
+                            'B', 'F', 'b', 'f' => {},
+                        }
+                    },
+                },
+                '@', 'A'...'Z', '_', 'a'...'z' => {},
+            };
+            return name.len > 0;
+        }
+    };
+    var labels: std.StringHashMapUnmanaged(Label) = .{};
+    defer {
+        var label_it = labels.valueIterator();
+        while (label_it.next()) |label| label.pending_relocs.deinit(func.gpa);
+        labels.deinit(func.gpa);
+    }
+
     const asm_source = std.mem.sliceAsBytes(func.air.extra[extra_i..])[0..extra.data.source_len];
     var line_it = mem.tokenizeAny(u8, asm_source, "\n\r;");
     next_line: while (line_it.next()) |line| {
         var mnem_it = mem.tokenizeAny(u8, line, " \t");
-        const instruction: union(enum) { mnem: Mnemonic, pseudo: Pseudo } = while (mnem_it.next()) |mnem_str| {
+        const mnem_str = while (mnem_it.next()) |mnem_str| {
             if (mem.startsWith(u8, mnem_str, "#")) continue :next_line;
             if (mem.startsWith(u8, mnem_str, "//")) continue :next_line;
-            if (std.meta.stringToEnum(Mnemonic, mnem_str)) |mnem| {
-                break .{ .mnem = mnem };
-            } else if (std.meta.stringToEnum(Pseudo, mnem_str)) |pseudo| {
-                break .{ .pseudo = pseudo };
-            } else return func.fail("TODO: airAsm labels, found '{s}'", .{mnem_str});
+            if (!mem.endsWith(u8, mnem_str, ":")) break mnem_str;
+            const label_name = mnem_str[0 .. mnem_str.len - ":".len];
+            if (!Label.isValid(.definition, label_name))
+                return func.fail("invalid label: '{s}'", .{label_name});
+
+            const label_gop = try labels.getOrPut(func.gpa, label_name);
+            if (!label_gop.found_existing) label_gop.value_ptr.* = .{} else {
+                const anon = std.ascii.isDigit(label_name[0]);
+                if (!anon and label_gop.value_ptr.pending_relocs.items.len == 0)
+                    return func.fail("redefined label: '{s}'", .{label_name});
+                for (label_gop.value_ptr.pending_relocs.items) |pending_reloc|
+                    func.performReloc(pending_reloc);
+                if (anon)
+                    label_gop.value_ptr.pending_relocs.clearRetainingCapacity()
+                else
+                    label_gop.value_ptr.pending_relocs.clearAndFree(func.gpa);
+            }
+            label_gop.value_ptr.target = @intCast(func.mir_instructions.len);
         } else continue;
 
+        const instruction: union(enum) { mnem: Mnemonic, pseudo: Pseudo } =
+            if (std.meta.stringToEnum(Mnemonic, mnem_str)) |mnem|
+            .{ .mnem = mnem }
+        else if (std.meta.stringToEnum(Pseudo, mnem_str)) |pseudo|
+            .{ .pseudo = pseudo }
+        else
+            return func.fail("invalid mnem str '{s}'", .{mnem_str});
+
         const Operand = union(enum) {
             none,
             reg: Register,
             imm: Immediate,
+            inst: Mir.Inst.Index,
             sym: SymbolOffset,
         };
 
         var ops: [4]Operand = .{.none} ** 4;
         var last_op = false;
-        var op_it = mem.splitScalar(u8, mnem_it.rest(), ',');
+        var op_it = mem.splitAny(u8, mnem_it.rest(), ",(");
         next_op: for (&ops) |*op| {
             const op_str = while (!last_op) {
                 const full_str = op_it.next() orelse break :next_op;
@@ -6109,6 +6149,25 @@ fn airAsm(func: *Func, inst: Air.Inst.Index) !void {
                         return func.fail("invalid modified '{s}'", .{modifier}),
                     else => return func.fail("invalid constraint: '{s}'", .{op_str}),
                 };
+            } else if (mem.endsWith(u8, op_str, ")")) {
+                const reg = op_str[0 .. op_str.len - ")".len];
+                const addr_reg = parseRegName(reg) orelse
+                    return func.fail("expected valid register, found '{s}'", .{reg});
+
+                op.* = .{ .reg = addr_reg };
+            } else if (Label.isValid(.reference, op_str)) {
+                const anon = std.ascii.isDigit(op_str[0]);
+                const label_gop = try labels.getOrPut(func.gpa, op_str[0..if (anon) 1 else op_str.len]);
+                if (!label_gop.found_existing) label_gop.value_ptr.* = .{};
+                if (anon and (op_str[1] == 'b' or op_str[1] == 'B') and !label_gop.found_existing)
+                    return func.fail("undefined label: '{s}'", .{op_str});
+                const pending_relocs = &label_gop.value_ptr.pending_relocs;
+                if (if (anon)
+                    op_str[1] == 'f' or op_str[1] == 'F'
+                else
+                    !label_gop.found_existing or pending_relocs.items.len > 0)
+                    try pending_relocs.append(func.gpa, @intCast(func.mir_instructions.len));
+                op.* = .{ .inst = label_gop.value_ptr.target };
             } else return func.fail("invalid operand: '{s}'", .{op_str});
         } else if (op_it.next()) |op_str| return func.fail("extra operand: '{s}'", .{op_str});
 
@@ -6131,6 +6190,39 @@ fn airAsm(func: *Func, inst: Air.Inst.Index) !void {
                             }),
                             else => error.InvalidInstruction,
                         },
+                        .imm => |imm1| switch (ops[2]) {
+                            .reg => |reg2| switch (mnem) {
+                                .sd => try func.addInst(.{
+                                    .tag = mnem,
+                                    .data = .{ .i_type = .{
+                                        .rd = reg2,
+                                        .rs1 = reg1,
+                                        .imm12 = imm1,
+                                    } },
+                                }),
+                                .ld => try func.addInst(.{
+                                    .tag = mnem,
+                                    .data = .{ .i_type = .{
+                                        .rd = reg1,
+                                        .rs1 = reg2,
+                                        .imm12 = imm1,
+                                    } },
+                                }),
+                                else => error.InvalidInstruction,
+                            },
+                            else => error.InvalidInstruction,
+                        },
+                        .none => switch (mnem) {
+                            .jalr => try func.addInst(.{
+                                .tag = mnem,
+                                .data = .{ .i_type = .{
+                                    .rd = .zero,
+                                    .rs1 = reg1,
+                                    .imm12 = Immediate.s(0),
+                                } },
+                            }),
+                            else => error.InvalidInstruction,
+                        },
                         else => error.InvalidInstruction,
                     },
                     else => error.InvalidInstruction,
@@ -6196,6 +6288,28 @@ fn airAsm(func: *Func, inst: Air.Inst.Index) !void {
                             } },
                         });
                     },
+                    .ret => _ = try func.addInst(.{
+                        .tag = .jalr,
+                        .data = .{ .i_type = .{
+                            .rd = .zero,
+                            .rs1 = .ra,
+                            .imm12 = Immediate.s(0),
+                        } },
+                    }),
+                    .beqz => blk: {
+                        if (ops[0] != .reg or ops[1] != .inst) {
+                            break :blk error.InvalidInstruction;
+                        }
+
+                        _ = try func.addInst(.{
+                            .tag = .beq,
+                            .data = .{ .b_type = .{
+                                .rs1 = ops[0].reg,
+                                .rs2 = .zero,
+                                .inst = ops[1].inst,
+                            } },
+                        });
+                    },
                 })) catch |err| {
                     switch (err) {
                         error.InvalidInstruction => return func.fail(
@@ -6215,6 +6329,10 @@ fn airAsm(func: *Func, inst: Air.Inst.Index) !void {
         }
     }
 
+    var label_it = labels.iterator();
+    while (label_it.next()) |label| if (label.value_ptr.pending_relocs.items.len > 0)
+        return func.fail("undefined label: '{s}'", .{label.key_ptr.*});
+
     for (outputs, args.items[0..outputs.len]) |output, arg_mcv| {
         const extra_bytes = mem.sliceAsBytes(func.air.extra[outputs_extra_i..]);
         const constraint =
@@ -7203,14 +7321,123 @@ fn airIntFromFloat(func: *Func, inst: Air.Inst.Index) !void {
     return func.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
-fn airCmpxchg(func: *Func, inst: Air.Inst.Index) !void {
+fn airCmpxchg(func: *Func, inst: Air.Inst.Index, strength: enum { weak, strong }) !void {
+    _ = strength; // TODO: do something with this
+
+    const pt = func.pt;
     const ty_pl = func.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
-    const extra = func.air.extraData(Air.Block, ty_pl.payload);
-    _ = extra;
-    return func.fail("TODO implement airCmpxchg for {}", .{
-        func.target.cpu.arch,
+    const extra = func.air.extraData(Air.Cmpxchg, ty_pl.payload).data;
+
+    const ptr_ty = func.typeOf(extra.ptr);
+    const val_ty = func.typeOf(extra.expected_value);
+    const val_abi_size: u32 = @intCast(val_ty.abiSize(pt));
+
+    switch (val_abi_size) {
+        1, 2, 4, 8 => {},
+        else => return func.fail("TODO: airCmpxchg Int size {}", .{val_abi_size}),
+    }
+
+    const succ_order: struct { aq: Mir.Barrier, rl: Mir.Barrier } = switch (extra.successOrder()) {
+        .unordered,
+        .release,
+        .acq_rel,
+        => unreachable,
+
+        .monotonic => .{ .aq = .none, .rl = .none },
+        .acquire => .{ .aq = .aq, .rl = .none },
+        .seq_cst => .{ .aq = .aq, .rl = .rl },
+    };
+
+    const ptr_mcv = try func.resolveInst(extra.ptr);
+    const ptr_reg, const ptr_lock = try func.promoteReg(ptr_ty, ptr_mcv);
+    defer if (ptr_lock) |lock| func.register_manager.unlockReg(lock);
+
+    const exp_mcv = try func.resolveInst(extra.expected_value);
+    const exp_reg, const exp_lock = try func.promoteReg(val_ty, exp_mcv);
+    defer if (exp_lock) |lock| func.register_manager.unlockReg(lock);
+    try func.truncateRegister(val_ty, exp_reg);
+
+    const new_mcv = try func.resolveInst(extra.new_value);
+    const new_reg, const new_lock = try func.promoteReg(val_ty, new_mcv);
+    defer if (new_lock) |lock| func.register_manager.unlockReg(lock);
+    try func.truncateRegister(val_ty, new_reg);
+
+    const branch_reg, const branch_lock = try func.allocReg(.int);
+    defer func.register_manager.unlockReg(branch_lock);
+
+    const fallthrough_reg, const fallthrough_lock = try func.allocReg(.int);
+    defer func.register_manager.unlockReg(fallthrough_lock);
+
+    const jump_back = try func.addInst(.{
+        .tag = if (val_ty.bitSize(pt) <= 32) .lrw else .lrd,
+        .data = .{ .amo = .{
+            .aq = succ_order.aq,
+            .rl = succ_order.rl,
+            .rd = branch_reg,
+            .rs1 = ptr_reg,
+            .rs2 = .zero,
+        } },
     });
-    // return func.finishAir(inst, result, .{ extra.ptr, extra.expected_value, extra.new_value });
+    try func.truncateRegister(val_ty, branch_reg);
+
+    const jump_forward = try func.addInst(.{
+        .tag = .bne,
+        .data = .{ .b_type = .{
+            .rs1 = branch_reg,
+            .rs2 = exp_reg,
+            .inst = undefined,
+        } },
+    });
+
+    _ = try func.addInst(.{
+        .tag = if (val_ty.bitSize(pt) <= 32) .scw else .scd,
+        .data = .{ .amo = .{
+            .aq = .none,
+            .rl = succ_order.rl,
+            .rd = fallthrough_reg,
+            .rs1 = ptr_reg,
+            .rs2 = new_reg,
+        } },
+    });
+    try func.truncateRegister(Type.bool, fallthrough_reg);
+
+    _ = try func.addInst(.{
+        .tag = .bne,
+        .data = .{ .b_type = .{
+            .rs1 = fallthrough_reg,
+            .rs2 = .zero,
+            .inst = jump_back,
+        } },
+    });
+
+    func.performReloc(jump_forward);
+
+    const result: MCValue = if (func.liveness.isUnused(inst)) .unreach else result: {
+        const dst_mcv = try func.allocRegOrMem(func.typeOfIndex(inst), inst, false);
+
+        const tmp_reg, const tmp_lock = try func.allocReg(.int);
+        defer func.register_manager.unlockReg(tmp_lock);
+
+        try func.genBinOp(
+            .cmp_neq,
+            .{ .register = branch_reg },
+            val_ty,
+            .{ .register = exp_reg },
+            val_ty,
+            tmp_reg,
+        );
+
+        try func.genCopy(val_ty, dst_mcv, .{ .register = branch_reg });
+        try func.genCopy(
+            Type.bool,
+            dst_mcv.address().offset(@intCast(val_abi_size)).deref(),
+            .{ .register = tmp_reg },
+        );
+
+        break :result dst_mcv;
+    };
+
+    return func.finishAir(inst, result, .{ extra.ptr, extra.expected_value, extra.new_value });
 }
 
 fn airAtomicRmw(func: *Func, inst: Air.Inst.Index) !void {
@@ -7234,8 +7461,8 @@ fn airAtomicRmw(func: *Func, inst: Air.Inst.Index) !void {
             return func.fail("TODO: airAtomicRmw non-pow 2", .{});
 
         switch (val_ty.zigTypeTag(pt.zcu)) {
-            .Int => {},
-            inline .Bool, .Float, .Enum, .Pointer => |ty| return func.fail("TODO: airAtomicRmw {s}", .{@tagName(ty)}),
+            .Enum, .Int => {},
+            inline .Bool, .Float, .Pointer => |ty| return func.fail("TODO: airAtomicRmw {s}", .{@tagName(ty)}),
             else => unreachable,
         }
 
src/arch/riscv64/mnem.zig
@@ -252,4 +252,6 @@ pub const Pseudo = enum(u8) {
     li,
     mv,
     tail,
+    beqz,
+    ret,
 };
test/behavior/atomics.zig
@@ -15,7 +15,6 @@ test "cmpxchg" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     try testCmpxchg();
     try comptime testCmpxchg();
@@ -108,7 +107,6 @@ test "cmpxchg with ignored result" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var x: i32 = 1234;
 
@@ -153,7 +151,6 @@ test "cmpxchg on a global variable" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .aarch64) {
         // https://github.com/ziglang/zig/issues/10627