Commit b5dedd7c00

Jacob Young <jacobly0@users.noreply.github.com>
2023-10-08 10:04:04
x86_64: implement `@mulAdd` of floats for baseline
1 parent 35c9b71
Changed files (2)
src
arch
test
behavior
src/arch/x86_64/CodeGen.zig
@@ -12721,142 +12721,165 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
     const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
     const ty = self.typeOfIndex(inst);
 
-    if (!self.hasFeature(.fma)) return self.fail("TODO implement airMulAdd for {}", .{ty.fmt(mod)});
-
     const ops = [3]Air.Inst.Ref{ extra.lhs, extra.rhs, pl_op.operand };
-    var mcvs: [3]MCValue = undefined;
-    var locks = [1]?RegisterManager.RegisterLock{null} ** 3;
-    defer for (locks) |reg_lock| if (reg_lock) |lock| self.register_manager.unlockReg(lock);
-    var order = [1]u2{0} ** 3;
-    var unused = std.StaticBitSet(3).initFull();
-    for (ops, &mcvs, &locks, 0..) |op, *mcv, *lock, op_i| {
-        const op_index: u2 = @intCast(op_i);
-        mcv.* = try self.resolveInst(op);
-        if (unused.isSet(0) and mcv.isRegister() and self.reuseOperand(inst, op, op_index, mcv.*)) {
-            order[op_index] = 1;
-            unused.unset(0);
-        } else if (unused.isSet(2) and mcv.isMemory()) {
-            order[op_index] = 3;
-            unused.unset(2);
+    const result = result: {
+        if (switch (ty.scalarType(mod).floatBits(self.target.*)) {
+            16, 80, 128 => true,
+            32, 64 => !self.hasFeature(.fma),
+            else => unreachable,
+        }) {
+            if (ty.zigTypeTag(mod) != .Float) return self.fail("TODO implement airMulAdd for {}", .{
+                ty.fmt(mod),
+            });
+
+            var callee: ["__fma?".len]u8 = undefined;
+            break :result try self.genCall(.{ .lib = .{
+                .return_type = ty.toIntern(),
+                .param_types = &.{ ty.toIntern(), ty.toIntern(), ty.toIntern() },
+                .callee = std.fmt.bufPrint(&callee, "{s}fma{s}", .{
+                    floatLibcAbiPrefix(ty),
+                    floatLibcAbiSuffix(ty),
+                }) catch unreachable,
+            } }, &.{ ty, ty, ty }, &.{
+                .{ .air_ref = extra.lhs }, .{ .air_ref = extra.rhs }, .{ .air_ref = pl_op.operand },
+            });
         }
-        switch (mcv.*) {
-            .register => |reg| lock.* = self.register_manager.lockReg(reg),
-            else => {},
+
+        var mcvs: [3]MCValue = undefined;
+        var locks = [1]?RegisterManager.RegisterLock{null} ** 3;
+        defer for (locks) |reg_lock| if (reg_lock) |lock| self.register_manager.unlockReg(lock);
+        var order = [1]u2{0} ** 3;
+        var unused = std.StaticBitSet(3).initFull();
+        for (ops, &mcvs, &locks, 0..) |op, *mcv, *lock, op_i| {
+            const op_index: u2 = @intCast(op_i);
+            mcv.* = try self.resolveInst(op);
+            if (unused.isSet(0) and mcv.isRegister() and self.reuseOperand(inst, op, op_index, mcv.*)) {
+                order[op_index] = 1;
+                unused.unset(0);
+            } else if (unused.isSet(2) and mcv.isMemory()) {
+                order[op_index] = 3;
+                unused.unset(2);
+            }
+            switch (mcv.*) {
+                .register => |reg| lock.* = self.register_manager.lockReg(reg),
+                else => {},
+            }
+        }
+        for (&order, &mcvs, &locks) |*mop_index, *mcv, *lock| {
+            if (mop_index.* != 0) continue;
+            mop_index.* = 1 + @as(u2, @intCast(unused.toggleFirstSet().?));
+            if (mop_index.* > 1 and mcv.isRegister()) continue;
+            const reg = try self.copyToTmpRegister(ty, mcv.*);
+            mcv.* = .{ .register = reg };
+            if (lock.*) |old_lock| self.register_manager.unlockReg(old_lock);
+            lock.* = self.register_manager.lockRegAssumeUnused(reg);
         }
-    }
-    for (&order, &mcvs, &locks) |*mop_index, *mcv, *lock| {
-        if (mop_index.* != 0) continue;
-        mop_index.* = 1 + @as(u2, @intCast(unused.toggleFirstSet().?));
-        if (mop_index.* > 1 and mcv.isRegister()) continue;
-        const reg = try self.copyToTmpRegister(ty, mcv.*);
-        mcv.* = .{ .register = reg };
-        if (lock.*) |old_lock| self.register_manager.unlockReg(old_lock);
-        lock.* = self.register_manager.lockRegAssumeUnused(reg);
-    }
 
-    const mir_tag = @as(?Mir.Inst.FixedTag, if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or
-        mem.eql(u2, &order, &.{ 3, 1, 2 }))
-        switch (ty.zigTypeTag(mod)) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .{ .v_ss, .fmadd132 },
-                64 => .{ .v_sd, .fmadd132 },
-                16, 80, 128 => null,
-                else => unreachable,
-            },
-            .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
-                .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_ss, .fmadd132 },
-                        2...8 => .{ .v_ps, .fmadd132 },
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_sd, .fmadd132 },
-                        2...4 => .{ .v_pd, .fmadd132 },
-                        else => null,
-                    },
+        const mir_tag = @as(?Mir.Inst.FixedTag, if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or
+            mem.eql(u2, &order, &.{ 3, 1, 2 }))
+            switch (ty.zigTypeTag(mod)) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .{ .v_ss, .fmadd132 },
+                    64 => .{ .v_sd, .fmadd132 },
                     16, 80, 128 => null,
                     else => unreachable,
                 },
-                else => unreachable,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
-        switch (ty.zigTypeTag(mod)) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .{ .v_ss, .fmadd213 },
-                64 => .{ .v_sd, .fmadd213 },
-                16, 80, 128 => null,
-                else => unreachable,
-            },
-            .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
-                .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_ss, .fmadd213 },
-                        2...8 => .{ .v_ps, .fmadd213 },
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_sd, .fmadd213 },
-                        2...4 => .{ .v_pd, .fmadd213 },
-                        else => null,
+                .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
+                    .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_ss, .fmadd132 },
+                            2...8 => .{ .v_ps, .fmadd132 },
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_sd, .fmadd132 },
+                            2...4 => .{ .v_pd, .fmadd132 },
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
-                    16, 80, 128 => null,
                     else => unreachable,
                 },
                 else => unreachable,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
-        switch (ty.zigTypeTag(mod)) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .{ .v_ss, .fmadd231 },
-                64 => .{ .v_sd, .fmadd231 },
-                16, 80, 128 => null,
-                else => unreachable,
-            },
-            .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
-                .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_ss, .fmadd231 },
-                        2...8 => .{ .v_ps, .fmadd231 },
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen(mod)) {
-                        1 => .{ .v_sd, .fmadd231 },
-                        2...4 => .{ .v_pd, .fmadd231 },
-                        else => null,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
+            switch (ty.zigTypeTag(mod)) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .{ .v_ss, .fmadd213 },
+                    64 => .{ .v_sd, .fmadd213 },
+                    16, 80, 128 => null,
+                    else => unreachable,
+                },
+                .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
+                    .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_ss, .fmadd213 },
+                            2...8 => .{ .v_ps, .fmadd213 },
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_sd, .fmadd213 },
+                            2...4 => .{ .v_pd, .fmadd213 },
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
+                    else => unreachable,
+                },
+                else => unreachable,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
+            switch (ty.zigTypeTag(mod)) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .{ .v_ss, .fmadd231 },
+                    64 => .{ .v_sd, .fmadd231 },
                     16, 80, 128 => null,
                     else => unreachable,
                 },
+                .Vector => switch (ty.childType(mod).zigTypeTag(mod)) {
+                    .Float => switch (ty.childType(mod).floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_ss, .fmadd231 },
+                            2...8 => .{ .v_ps, .fmadd231 },
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen(mod)) {
+                            1 => .{ .v_sd, .fmadd231 },
+                            2...4 => .{ .v_pd, .fmadd231 },
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
+                    },
+                    else => unreachable,
+                },
                 else => unreachable,
-            },
-            else => unreachable,
-        }
-    else
-        unreachable) orelse return self.fail("TODO implement airMulAdd for {}", .{ty.fmt(mod)});
+            }
+        else
+            unreachable) orelse return self.fail("TODO implement airMulAdd for {}", .{ty.fmt(mod)});
 
-    var mops: [3]MCValue = undefined;
-    for (order, mcvs) |mop_index, mcv| mops[mop_index - 1] = mcv;
+        var mops: [3]MCValue = undefined;
+        for (order, mcvs) |mop_index, mcv| mops[mop_index - 1] = mcv;
 
-    const abi_size: u32 = @intCast(ty.abiSize(mod));
-    const mop1_reg = registerAlias(mops[0].getReg().?, abi_size);
-    const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
-    if (mops[2].isRegister()) try self.asmRegisterRegisterRegister(
-        mir_tag,
-        mop1_reg,
-        mop2_reg,
-        registerAlias(mops[2].getReg().?, abi_size),
-    ) else try self.asmRegisterRegisterMemory(
-        mir_tag,
-        mop1_reg,
-        mop2_reg,
-        mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
-    );
-    return self.finishAir(inst, mops[0], ops);
+        const abi_size: u32 = @intCast(ty.abiSize(mod));
+        const mop1_reg = registerAlias(mops[0].getReg().?, abi_size);
+        const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
+        if (mops[2].isRegister()) try self.asmRegisterRegisterRegister(
+            mir_tag,
+            mop1_reg,
+            mop2_reg,
+            registerAlias(mops[2].getReg().?, abi_size),
+        ) else try self.asmRegisterRegisterMemory(
+            mir_tag,
+            mop1_reg,
+            mop2_reg,
+            mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
+        );
+        break :result mops[0];
+    };
+    return self.finishAir(inst, result, ops);
 }
 
 fn airVaStart(self: *Self, inst: Air.Inst.Index) !void {
test/behavior/muladd.zig
@@ -32,11 +32,11 @@ fn testMulAdd() !void {
 }
 
 test "@mulAdd f16" {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
     try comptime testMulAdd16();
     try testMulAdd16();
@@ -50,12 +50,12 @@ fn testMulAdd16() !void {
 }
 
 test "@mulAdd f80" {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_c and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
     try comptime testMulAdd80();
     try testMulAdd80();
@@ -69,12 +69,12 @@ fn testMulAdd80() !void {
 }
 
 test "@mulAdd f128" {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_c and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
     try comptime testMulAdd128();
     try testMulAdd128();