Commit b992ea1b07

Cody Tapscott <topolarity@tapscott.me>
2022-10-08 20:27:29
stage1: Rely on softfloat for `f16` on non-arm targets
1 parent 37c6fca
Changed files (1)
src
src/stage1/codegen.cpp
@@ -80,6 +80,7 @@ void codegen_set_strip(CodeGen *g, bool strip) {
     }
 }
 
+static LLVMValueRef get_soft_float_fn(CodeGen *g, const char *name, int param_count, LLVMTypeRef param_type, LLVMTypeRef return_type);
 static void render_const_val(CodeGen *g, ZigValue *const_val, const char *name);
 static void render_const_val_global(CodeGen *g, ZigValue *const_val, const char *name);
 static LLVMValueRef gen_const_val(CodeGen *g, ZigValue *const_val, const char *name);
@@ -1736,12 +1737,7 @@ static LLVMValueRef gen_soft_float_widen_or_shorten(CodeGen *g, ZigType *actual_
         }
     }
 
-    LLVMValueRef func_ref = LLVMGetNamedFunction(g->module, fn_name);
-    if (func_ref == nullptr) {
-        LLVMTypeRef fn_type = LLVMFunctionType(return_type, &param_type, 1, false);
-        func_ref = LLVMAddFunction(g->module, fn_name, fn_type);
-    }
-
+    LLVMValueRef func_ref = get_soft_float_fn(g, fn_name, 1, param_type, return_type);
     result = LLVMBuildCall2(g->builder, LLVMGlobalGetValueType(func_ref), func_ref, &expr_val, 1, "");
 
     // On non-Arm platforms we need to bitcast __trunc<>fhf2 result back to f16
@@ -1766,9 +1762,12 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
     uint64_t wanted_bits;
     if (scalar_actual_type->id == ZigTypeIdFloat) {
 
-        if ((scalar_actual_type == g->builtin_types.entry_f80
+        if (((scalar_actual_type == g->builtin_types.entry_f80
             || scalar_wanted_type == g->builtin_types.entry_f80)
-         && !target_has_f80(g->zig_target))
+         && !target_has_f80(g->zig_target)) ||
+            ((scalar_actual_type == g->builtin_types.entry_f16
+            || scalar_wanted_type == g->builtin_types.entry_f16)
+         && !target_is_arm(g->zig_target)))
         {
             return gen_soft_float_widen_or_shorten(g, actual_type, wanted_type, expr_val);
         }
@@ -3100,6 +3099,7 @@ static LLVMValueRef gen_float_un_op(CodeGen *g, LLVMValueRef operand, ZigType *o
     ZigType *elem_type = operand_type->id == ZigTypeIdVector ? operand_type->data.vector.elem_type : operand_type;
     if ((elem_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
         (elem_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+        (elem_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target)) ||
         op == BuiltinFnIdTan)
     {
         return gen_soft_float_un_op(g, operand, operand_type, op);
@@ -3690,7 +3690,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, Stage1Air *executable,
     ZigType *operand_type = op1->value->type;
     ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
     if ((scalar_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
-        (scalar_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target))) {
+        (scalar_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+        (scalar_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target))) {
         // LLVM incorrectly lowers the soft float calls for f128 as if they operated on `long double`.
         // On some targets this will be incorrect, so we manually lower the call ourselves.
         LLVMValueRef op1_value = ir_llvm_value(g, op1);
@@ -4024,7 +4025,8 @@ static LLVMValueRef ir_render_cast(CodeGen *g, Stage1Air *executable,
             assert(actual_type->id == ZigTypeIdInt);
             {
                 if ((wanted_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
-                    (wanted_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target))) {
+                    (wanted_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+                    (wanted_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target))) {
                     return gen_soft_int_to_float_op(g, expr_val, actual_type, wanted_type);
                 } else {
                     if (actual_type->data.integral.is_signed) {
@@ -4042,7 +4044,8 @@ static LLVMValueRef ir_render_cast(CodeGen *g, Stage1Air *executable,
 
             LLVMValueRef result;
             if ((actual_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
-                (actual_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target))) {
+                (actual_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+                (actual_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target))) {
                 result = gen_soft_float_to_int_op(g, expr_val, actual_type, wanted_type);
             } else {
                 if (wanted_type->data.integral.is_signed) {
@@ -4396,7 +4399,8 @@ static LLVMValueRef gen_negation(CodeGen *g, Stage1AirInst *inst, Stage1AirInst
         operand_type->data.vector.elem_type : operand_type;
 
     if ((scalar_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
-        (scalar_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target))) {
+        (scalar_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+        (scalar_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target))) {
         return gen_soft_float_neg(g, operand_type, llvm_operand);
     }
 
@@ -7374,7 +7378,9 @@ static LLVMValueRef ir_render_soft_mul_add(CodeGen *g, Stage1Air *executable, St
     uint32_t vector_len = operand_type->id == ZigTypeIdVector ? operand_type->data.vector.len : 0;
 
     const char *fn_name;
-    if (float_type == g->builtin_types.entry_f32)
+    if (float_type == g->builtin_types.entry_f16)
+        fn_name = "__fmah";
+    else if (float_type == g->builtin_types.entry_f32)
         fn_name = "fmaf";
     else if (float_type == g->builtin_types.entry_f64)
         fn_name = "fma";
@@ -7385,13 +7391,8 @@ static LLVMValueRef ir_render_soft_mul_add(CodeGen *g, Stage1Air *executable, St
     else
         zig_unreachable();
 
-    LLVMValueRef func_ref = LLVMGetNamedFunction(g->module, fn_name);
-    if (func_ref == nullptr) {
-        LLVMTypeRef float_type_ref = float_type->llvm_type;
-        LLVMTypeRef params[3] = { float_type_ref, float_type_ref, float_type_ref };
-        LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, params, 3, false);
-        func_ref = LLVMAddFunction(g->module, fn_name, fn_type);
-    }
+    LLVMTypeRef float_type_ref = float_type->llvm_type;
+    LLVMValueRef func_ref = get_soft_float_fn(g, fn_name, 3, float_type_ref, float_type_ref); 
 
     LLVMValueRef op1 = ir_llvm_value(g, instruction->op1);
     LLVMValueRef op2 = ir_llvm_value(g, instruction->op2);
@@ -7421,7 +7422,8 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, Stage1Air *executable, Stage1A
     ZigType *operand_type = instruction->op1->value->type;
     operand_type = operand_type->id == ZigTypeIdVector ? operand_type->data.vector.elem_type : operand_type;
     if ((operand_type == g->builtin_types.entry_f80 && !target_has_f80(g->zig_target)) ||
-        (operand_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target))) {
+        (operand_type == g->builtin_types.entry_f128 && !target_long_double_is_f128(g->zig_target)) ||
+        (operand_type == g->builtin_types.entry_f16 && !target_is_arm(g->zig_target))) {
         return ir_render_soft_mul_add(g, executable, instruction, operand_type);
     }
     LLVMValueRef op1 = ir_llvm_value(g, instruction->op1);
@@ -9740,7 +9742,12 @@ static void define_builtin_types(CodeGen *g) {
         }
     }
 
-    add_fp_entry(g, "f16", 16, LLVMHalfType(), &g->builtin_types.entry_f16);
+    if (target_is_arm(g->zig_target)) {
+        add_fp_entry(g, "f16", 16, LLVMHalfType(), &g->builtin_types.entry_f16);
+    } else {
+        ZigType *u16_ty = get_int_type(g, false, 16);
+        add_fp_entry(g, "f16", 16, get_llvm_type(g, u16_ty), &g->builtin_types.entry_f16);
+    }
     add_fp_entry(g, "f32", 32, LLVMFloatType(), &g->builtin_types.entry_f32);
     add_fp_entry(g, "f64", 64, LLVMDoubleType(), &g->builtin_types.entry_f64);
     add_fp_entry(g, "f128", 128, LLVMFP128Type(), &g->builtin_types.entry_f128);