Commit f6cdc94a50
Changed files (1)
src
src/codegen.cpp
@@ -155,7 +155,6 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr,
LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type,
LLVMValueRef result_loc, bool non_async);
static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix);
-static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val);
static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) {
unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name));
@@ -2536,6 +2535,36 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}
+enum class ScalarizePredicate {
+ // Returns true iff all the elements in the vector are 1.
+ // Equivalent to folding all the bits with `and`.
+ All,
+ // Returns true iff there's at least one element in the vector that is 1.
+ // Equivalent to folding all the bits with `or`.
+ Any,
+};
+
+// Collapses a <N x i1> vector into a single i1 according to the given predicate
+static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
+ assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
+ LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
+ LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
+
+ switch (predicate) {
+ case ScalarizePredicate::Any: {
+ LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
+ return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
+ }
+ case ScalarizePredicate::All: {
+ LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
+ return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
+ }
+ }
+
+ zig_unreachable();
+}
+
+
static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMValueRef val1, LLVMValueRef val2)
{
@@ -2560,7 +2589,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
- ok_bit = scalarize_cmp_result(g, ok_bit);
+ ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@@ -2591,7 +2620,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
- ok_bit = scalarize_cmp_result(g, ok_bit);
+ ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@@ -2647,16 +2676,6 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) {
}
}
-// Collapses a <N x i1> vector into a single i1 whose value is 1 iff all the
-// vector elements are 1
-static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val) {
- assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
- LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
- LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
- LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
- return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
-}
-
static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math,
LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind)
{
@@ -2678,7 +2697,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
- is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
+ is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
}
LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@@ -2703,7 +2722,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
- overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit);
+ overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
@@ -2728,7 +2747,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
- ok_bit = scalarize_cmp_result(g, ok_bit);
+ ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@@ -2745,7 +2764,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
- ltz = scalarize_cmp_result(g, ltz);
+ ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
@@ -2797,7 +2816,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
- ok_bit = scalarize_cmp_result(g, ok_bit);
+ ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@@ -2861,7 +2880,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}
if (operand_type->id == ZigTypeIdVector) {
- is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
+ is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
}
LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@@ -2918,7 +2937,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
if (rhs_type->id == ZigTypeIdVector) {
- less_than_bit = scalarize_cmp_result(g, less_than_bit);
+ less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
}
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);