Commit f6cdc94a50

LemonBoy <thatlemon@gmail.com>
2020-04-05 10:40:41
ir: Fix error checking for vector ops
The extra logic that's needed was lost during a refactoring, now it should be fine.
1 parent 0f964e1
Changed files (1)
src/codegen.cpp
@@ -155,7 +155,6 @@ static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstGen *source_instr,
         LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type,
         LLVMValueRef result_loc, bool non_async);
 static Error get_tmp_filename(CodeGen *g, Buf *out, Buf *suffix);
-static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val);
 
 static void addLLVMAttr(LLVMValueRef val, LLVMAttributeIndex attr_index, const char *attr_name) {
     unsigned kind_id = LLVMGetEnumAttributeKindForName(attr_name, strlen(attr_name));
@@ -2536,6 +2535,36 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
     return nullptr;
 }
 
+enum class ScalarizePredicate {
+    // Returns true iff all the elements in the vector are 1.
+    // Equivalent to folding all the bits with `and`.
+    All,
+    // Returns true iff there's at least one element in the vector that is 1.
+    // Equivalent to folding all the bits with `or`.
+    Any,
+};
+
+// Collapses a <N x i1> vector into a single i1 according to the given predicate
+static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
+    assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
+    LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
+    LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
+
+    switch (predicate) {
+        case ScalarizePredicate::Any: {
+            LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
+            return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
+        }
+        case ScalarizePredicate::All: {
+            LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
+            return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
+        }
+    }
+
+    zig_unreachable();
+}
+
+
 static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
     LLVMValueRef val1, LLVMValueRef val2)
 {
@@ -2560,7 +2589,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
     LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
     if (operand_type->id == ZigTypeIdVector) {
-        ok_bit = scalarize_cmp_result(g, ok_bit);
+        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
     }
     LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2591,7 +2620,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
     LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
     if (operand_type->id == ZigTypeIdVector) {
-        ok_bit = scalarize_cmp_result(g, ok_bit);
+        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
     }
     LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2647,16 +2676,6 @@ static LLVMValueRef bigint_to_llvm_const(LLVMTypeRef type_ref, BigInt *bigint) {
     }
 }
 
-// Collapses a <N x i1> vector into a single i1 whose value is 1 iff all the
-// vector elements are 1
-static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val) {
-    assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
-    LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
-    LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
-    LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
-    return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
-}
-
 static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast_math,
     LLVMValueRef val1, LLVMValueRef val2, ZigType *operand_type, DivKind div_kind)
 {
@@ -2678,7 +2697,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
         }
 
         if (operand_type->id == ZigTypeIdVector) {
-            is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
+            is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
         }
 
         LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@@ -2703,7 +2722,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
             LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
             LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
             if (operand_type->id == ZigTypeIdVector) {
-                overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit);
+                overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
             }
             LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
 
@@ -2728,7 +2747,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
                     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
                     LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
                     if (operand_type->id == ZigTypeIdVector) {
-                        ok_bit = scalarize_cmp_result(g, ok_bit);
+                        ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
                     }
                     LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2745,7 +2764,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
                     LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
                     LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
                     if (operand_type->id == ZigTypeIdVector) {
-                        ltz = scalarize_cmp_result(g, ltz);
+                        ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
                     }
                     LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
 
@@ -2797,7 +2816,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
                 LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
                 LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
                 if (operand_type->id == ZigTypeIdVector) {
-                    ok_bit = scalarize_cmp_result(g, ok_bit);
+                    ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
                 }
                 LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
@@ -2861,7 +2880,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
         }
 
         if (operand_type->id == ZigTypeIdVector) {
-            is_zero_bit = scalarize_cmp_result(g, is_zero_bit);
+            is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
         }
 
         LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@@ -2918,7 +2937,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
         LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
         LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
         if (rhs_type->id == ZigTypeIdVector) {
-            less_than_bit = scalarize_cmp_result(g, less_than_bit);
+            less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
         }
         LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);