Commit 4ee91bb8a8

Andrew Kelley <andrew@ziglang.org>
2021-10-05 23:20:12
stage1: work around LLVM's buggy fma lowering
* move fmaq from freestanding libc to compiler_rt, unconditionally exported weak_odr. * stage1: add fmaf, fmal, fmaq as symbols that compiler-rt might generate calls to. * stage1: lower `@mulAdd` directly to a call to `fmaq` instead of to the LLVM intrinsic because LLVM will lower it to `fmal` even when the target's `long double` is not equivalent to `f128`. This commit is intended to fix the test suite which is failing on the previous commit.
1 parent 5e153b5
Changed files (5)
lib/std/special/c_stage1.zig
@@ -656,10 +656,6 @@ export fn ceil(x: f64) f64 {
     return math.ceil(x);
 }
 
-export fn fmal(a: f128, b: f128, c: f128) f128 {
-    return math.fma(f128, a, b, c);
-}
-
 export fn fma(a: f64, b: f64, c: f64) f64 {
     return math.fma(f64, a, b, c);
 }
lib/std/special/compiler_rt.zig
@@ -616,9 +616,15 @@ comptime {
         @export(__mulodi4, .{ .name = "__mulodi4", .linkage = linkage });
 
         _ = @import("compiler_rt/atomics.zig");
+
+        @export(fmaq, .{ .name = "fmaq", .linkage = linkage });
     }
 }
 
+fn fmaq(a: f128, b: f128, c: f128) callconv(.C) f128 {
+    return std.math.fma(f128, a, b, c);
+}
+
 // Avoid dragging in the runtime safety mechanisms into this .o file,
 // unless we're trying to test this file.
 pub fn panic(msg: []const u8, error_return_trace: ?*std.builtin.StackTrace) noreturn {
src/stage1/codegen.cpp
@@ -57,6 +57,9 @@ static const char *symbols_that_llvm_depends_on[] = {
     "log10",
     "log2",
     "fma",
+    "fmaf",
+    "fmal",
+    "fmaq",
     "fabs",
     "minnum",
     "maxnum",
@@ -832,10 +835,25 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn
 
     bool is_vector = (type_entry->id == ZigTypeIdVector);
     ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry;
+    uint32_t float_bits = float_type->data.floating.bit_count;
+
+    // LLVM incorrectly lowers the fma builtin for f128 to fmal, which is for
+    // `long double`. On some targets this will be correct; on others it will be incorrect.
+    if (fn_id == ZigLLVMFnIdFMA && float_bits == 128 &&
+        !target_long_double_is_f128(g->zig_target))
+    {
+        LLVMValueRef existing_llvm_fn = LLVMGetNamedFunction(g->module, "fmaq");
+        if (existing_llvm_fn != nullptr) return existing_llvm_fn;
+
+        LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry);
+        LLVMTypeRef return_elem_types[3] = { float_type_ref, float_type_ref, float_type_ref };
+        LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, 3, false);
+        return LLVMAddFunction(g->module, "fmaq", fn_type);
+    }
 
     ZigLLVMFnKey key = {};
     key.id = fn_id;
-    key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count;
+    key.data.floating.bit_count = float_bits;
     key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0;
     key.data.floating.op = op;
 
@@ -861,11 +879,7 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn
     else
         sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count);
     LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry);
-    LLVMTypeRef return_elem_types[3] = {
-        float_type_ref,
-        float_type_ref,
-        float_type_ref,
-    };
+    LLVMTypeRef return_elem_types[3] = { float_type_ref, float_type_ref, float_type_ref };
     LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false);
     LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
     assert(LLVMGetIntrinsicID(fn_val));
@@ -6583,11 +6597,7 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, Stage1Air *executable, Stage1A
     assert(instruction->base.value->type->id == ZigTypeIdFloat ||
            instruction->base.value->type->id == ZigTypeIdVector);
     LLVMValueRef fn_val = get_float_fn(g, instruction->base.value->type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd);
-    LLVMValueRef args[3] = {
-        op1,
-        op2,
-        op3,
-    };
+    LLVMValueRef args[3] = { op1, op2, op3 };
     return LLVMBuildCall(g->builder, fn_val, args, 3, "");
 }
 
src/stage1/target.cpp
@@ -999,6 +999,22 @@ bool target_has_debug_info(const ZigTarget *target) {
     return !target_is_wasm(target);
 }
 
+bool target_long_double_is_f128(const ZigTarget *target) {
+    switch (target->arch) {
+        case ZigLLVM_riscv64:
+        case ZigLLVM_aarch64:
+        case ZigLLVM_aarch64_be:
+        case ZigLLVM_aarch64_32:
+        case ZigLLVM_systemz:
+        case ZigLLVM_mips64:
+        case ZigLLVM_mips64el:
+            return true;
+
+        default:
+            return false;
+    }
+}
+
 bool target_is_riscv(const ZigTarget *target) {
     return target->arch == ZigLLVM_riscv32 || target->arch == ZigLLVM_riscv64;
 }
src/stage1/target.hpp
@@ -79,6 +79,7 @@ bool target_is_riscv(const ZigTarget *target);
 bool target_is_sparc(const ZigTarget *target);
 bool target_is_android(const ZigTarget *target);
 bool target_has_debug_info(const ZigTarget *target);
+bool target_long_double_is_f128(const ZigTarget *target);
 
 uint32_t target_arch_pointer_bit_width(ZigLLVM_ArchType arch);
 uint32_t target_arch_largest_atomic_bits(ZigLLVM_ArchType arch);