Commit 0a7bdc0077

Andrew Kelley <andrew@ziglang.org>
2019-02-09 20:44:33
implement vector addition with safety checking
this would work if @llvm.sadd.with.overflow supported vectors, which it does in trunk. but it does not support them in llvm 7 or even in llvm 8 release branch. so the next commit after this will have to do a different strategy, but when llvm 9 comes out it may be worth coming back to this one.
1 parent a8a63fe
src/all_types.hpp
@@ -1538,6 +1538,8 @@ enum ZigLLVMFnId {
     ZigLLVMFnIdBitReverse,
 };
 
+// There are a bunch of places in code that rely on these values being in
+// exactly this order.
 enum AddSubMul {
     AddSubMulAdd = 0,
     AddSubMulSub = 1,
@@ -1563,6 +1565,7 @@ struct ZigLLVMFnKey {
         struct {
             AddSubMul add_sub_mul;
             uint32_t bit_count;
+            uint32_t vector_len; // 0 means not a vector
             bool is_signed;
         } overflow_arithmetic;
         struct {
src/analyze.cpp
@@ -6361,7 +6361,8 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
         case ZigLLVMFnIdOverflowArithmetic:
             return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) +
                 ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) +
-                ((uint32_t)(x.data.overflow_arithmetic.is_signed) ? 1062315172 : 314955820);
+                ((uint32_t)(x.data.overflow_arithmetic.is_signed) ? 1062315172 : 314955820) +
+                x.data.overflow_arithmetic.vector_len * 1435156945;
     }
     zig_unreachable();
 }
@@ -6387,7 +6388,8 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
         case ZigLLVMFnIdOverflowArithmetic:
             return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&
                 (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) &&
-                (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed);
+                (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed) &&
+                (a.data.overflow_arithmetic.vector_len == b.data.overflow_arithmetic.vector_len);
     }
     zig_unreachable();
 }
src/codegen.cpp
@@ -715,38 +715,59 @@ static void clear_debug_source_node(CodeGen *g) {
     ZigLLVMClearCurrentDebugLocation(g->builder);
 }
 
-static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, ZigType *type_entry,
+static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, ZigType *operand_type,
         const char *signed_name, const char *unsigned_name)
 {
+    ZigType *int_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
     char fn_name[64];
 
-    assert(type_entry->id == ZigTypeIdInt);
-    const char *signed_str = type_entry->data.integral.is_signed ? signed_name : unsigned_name;
-    sprintf(fn_name, "llvm.%s.with.overflow.i%" PRIu32, signed_str, type_entry->data.integral.bit_count);
+    assert(int_type->id == ZigTypeIdInt);
+    const char *signed_str = int_type->data.integral.is_signed ? signed_name : unsigned_name;
 
-    LLVMTypeRef return_elem_types[] = {
-        type_entry->type_ref,
-        LLVMInt1Type(),
-    };
     LLVMTypeRef param_types[] = {
-        type_entry->type_ref,
-        type_entry->type_ref,
+        operand_type->type_ref,
+        operand_type->type_ref,
     };
-    LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
-    LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
-    LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
-    assert(LLVMGetIntrinsicID(fn_val));
-    return fn_val;
+
+    if (operand_type->id == ZigTypeIdVector) {
+        sprintf(fn_name, "llvm.%s.with.overflow.v%" PRIu32 "i%" PRIu32, signed_str,
+                operand_type->data.vector.len, int_type->data.integral.bit_count);
+
+        LLVMTypeRef return_elem_types[] = {
+            operand_type->type_ref,
+            LLVMVectorType(LLVMInt1Type(), operand_type->data.vector.len),
+        };
+        LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
+        LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
+        LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
+        assert(LLVMGetIntrinsicID(fn_val));
+        return fn_val;
+    } else {
+        sprintf(fn_name, "llvm.%s.with.overflow.i%" PRIu32, signed_str, int_type->data.integral.bit_count);
+
+        LLVMTypeRef return_elem_types[] = {
+            operand_type->type_ref,
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
+        LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
+        LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
+        assert(LLVMGetIntrinsicID(fn_val));
+        return fn_val;
+    }
 }
 
-static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *type_entry, AddSubMul add_sub_mul) {
-    assert(type_entry->id == ZigTypeIdInt);
+static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSubMul add_sub_mul) {
+    ZigType *int_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
+    assert(int_type->id == ZigTypeIdInt);
 
     ZigLLVMFnKey key = {};
     key.id = ZigLLVMFnIdOverflowArithmetic;
-    key.data.overflow_arithmetic.is_signed = type_entry->data.integral.is_signed;
+    key.data.overflow_arithmetic.is_signed = int_type->data.integral.is_signed;
     key.data.overflow_arithmetic.add_sub_mul = add_sub_mul;
-    key.data.overflow_arithmetic.bit_count = (uint32_t)type_entry->data.integral.bit_count;
+    key.data.overflow_arithmetic.bit_count = (uint32_t)int_type->data.integral.bit_count;
+    key.data.overflow_arithmetic.vector_len = (operand_type->id == ZigTypeIdVector) ?
+        operand_type->data.vector.len : 0;
 
     auto existing_entry = g->llvm_fn_table.maybe_get(key);
     if (existing_entry)
@@ -755,13 +776,13 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *type_entry, AddSubM
     LLVMValueRef fn_val;
     switch (add_sub_mul) {
         case AddSubMulAdd:
-            fn_val = get_arithmetic_overflow_fn(g, type_entry, "sadd", "uadd");
+            fn_val = get_arithmetic_overflow_fn(g, operand_type, "sadd", "uadd");
             break;
         case AddSubMulSub:
-            fn_val = get_arithmetic_overflow_fn(g, type_entry, "ssub", "usub");
+            fn_val = get_arithmetic_overflow_fn(g, operand_type, "ssub", "usub");
             break;
         case AddSubMulMul:
-            fn_val = get_arithmetic_overflow_fn(g, type_entry, "smul", "umul");
+            fn_val = get_arithmetic_overflow_fn(g, operand_type, "smul", "umul");
             break;
     }
 
@@ -1752,17 +1773,28 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
     }
 }
 
-static LLVMValueRef gen_overflow_op(CodeGen *g, ZigType *type_entry, AddSubMul op,
+static LLVMValueRef gen_overflow_op(CodeGen *g, ZigType *operand_type, AddSubMul op,
         LLVMValueRef val1, LLVMValueRef val2)
 {
-    LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op);
+    LLVMValueRef fn_val = get_int_overflow_fn(g, operand_type, op);
     LLVMValueRef params[] = {
         val1,
         val2,
     };
     LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, "");
     LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
-    LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+
+    LLVMValueRef overflow_bit;
+    if (operand_type->id == ZigTypeIdVector) {
+        LLVMValueRef overflow_vector = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+        LLVMTypeRef bigger_int_type_ref = LLVMIntType(operand_type->data.vector.len);
+        LLVMValueRef bitcasted_overflow = LLVMBuildBitCast(g->builder, overflow_vector, bigger_int_type_ref, "");
+        LLVMValueRef zero = LLVMConstNull(bigger_int_type_ref);
+        overflow_bit = LLVMBuildICmp(g->builder, LLVMIntNE, bitcasted_overflow, zero, "");
+    } else {
+        overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+    }
+
     LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
     LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
     LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block);
@@ -2608,7 +2640,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
             (op_id == IrBinOpAdd || op_id == IrBinOpSub) &&
             op1->value.type->data.pointer.ptr_len == PtrLenUnknown)
     );
-    ZigType *type_entry = op1->value.type;
+    ZigType *operand_type = op1->value.type;
+    ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
 
     bool want_runtime_safety = bin_op_instruction->safety_check_on &&
         ir_want_runtime_safety(g, &bin_op_instruction->base);
@@ -2634,17 +2667,17 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         case IrBinOpCmpGreaterThan:
         case IrBinOpCmpLessOrEq:
         case IrBinOpCmpGreaterOrEq:
-            if (type_entry->id == ZigTypeIdFloat) {
+            if (scalar_type->id == ZigTypeIdFloat) {
                 ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &bin_op_instruction->base));
                 LLVMRealPredicate pred = cmp_op_to_real_predicate(op_id);
                 return LLVMBuildFCmp(g->builder, pred, op1_value, op2_value, "");
-            } else if (type_entry->id == ZigTypeIdInt) {
-                LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, type_entry->data.integral.is_signed);
+            } else if (scalar_type->id == ZigTypeIdInt) {
+                LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, scalar_type->data.integral.is_signed);
                 return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, "");
-            } else if (type_entry->id == ZigTypeIdEnum ||
-                    type_entry->id == ZigTypeIdErrorSet ||
-                    type_entry->id == ZigTypeIdBool ||
-                    get_codegen_ptr_type(type_entry) != nullptr)
+            } else if (scalar_type->id == ZigTypeIdEnum ||
+                    scalar_type->id == ZigTypeIdErrorSet ||
+                    scalar_type->id == ZigTypeIdBool ||
+                    get_codegen_ptr_type(scalar_type) != nullptr)
             {
                 LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, false);
                 return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, "");
@@ -2665,23 +2698,16 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
             static const BuildBinOpFunc signed_op[3] = { LLVMBuildNSWAdd, LLVMBuildNSWSub, LLVMBuildNSWMul };
             static const BuildBinOpFunc unsigned_op[3] = { LLVMBuildNUWAdd, LLVMBuildNUWSub, LLVMBuildNUWMul };
 
-            bool is_vector = type_entry->id == ZigTypeIdVector;
             bool is_wrapping = (op_id == IrBinOpSubWrap || op_id == IrBinOpAddWrap || op_id == IrBinOpMultWrap);
             AddSubMul add_sub_mul =
                 op_id == IrBinOpAdd || op_id == IrBinOpAddWrap ? AddSubMulAdd :
                 op_id == IrBinOpSub || op_id == IrBinOpSubWrap ? AddSubMulSub :
                 AddSubMulMul;
 
-            // The code that is generated for vectors and scalars are the same,
-            // so we can just set type_entry to the vectors elem_type an avoid
-            // a lot of repeated code.
-            if (is_vector)
-                type_entry = type_entry->data.vector.elem_type;
-
-            if (type_entry->id == ZigTypeIdPointer) {
-                assert(type_entry->data.pointer.ptr_len == PtrLenUnknown);
+            if (scalar_type->id == ZigTypeIdPointer) {
+                assert(scalar_type->data.pointer.ptr_len == PtrLenUnknown);
                 LLVMValueRef subscript_value;
-                if (is_vector)
+                if (operand_type->id == ZigTypeIdVector)
                     zig_panic("TODO: Implement vector operations on pointers.");
 
                 switch (add_sub_mul) {
@@ -2697,17 +2723,15 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
 
                 // TODO runtime safety
                 return LLVMBuildInBoundsGEP(g->builder, op1_value, &subscript_value, 1, "");
-            } else if (type_entry->id == ZigTypeIdFloat) {
+            } else if (scalar_type->id == ZigTypeIdFloat) {
                 ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &bin_op_instruction->base));
                 return float_op[add_sub_mul](g->builder, op1_value, op2_value, "");
-            } else if (type_entry->id == ZigTypeIdInt) {
+            } else if (scalar_type->id == ZigTypeIdInt) {
                 if (is_wrapping) {
                     return wrap_op[add_sub_mul](g->builder, op1_value, op2_value, "");
                 } else if (want_runtime_safety) {
-                    if (is_vector)
-                        zig_panic("TODO: Implement runtime safety vector operations.");
-                    return gen_overflow_op(g, type_entry, add_sub_mul, op1_value, op2_value);
-                } else if (type_entry->data.integral.is_signed) {
+                    return gen_overflow_op(g, operand_type, add_sub_mul, op1_value, op2_value);
+                } else if (scalar_type->data.integral.is_signed) {
                     return signed_op[add_sub_mul](g->builder, op1_value, op2_value, "");
                 } else {
                     return unsigned_op[add_sub_mul](g->builder, op1_value, op2_value, "");
@@ -2725,15 +2749,14 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         case IrBinOpBitShiftLeftLossy:
         case IrBinOpBitShiftLeftExact:
             {
-                assert(type_entry->id == ZigTypeIdInt);
-                LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type,
-                        type_entry, op2_value);
+                assert(scalar_type->id == ZigTypeIdInt);
+                LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, scalar_type, op2_value);
                 bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
                 if (is_sloppy) {
                     return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
                 } else if (want_runtime_safety) {
-                    return gen_overflow_shl_op(g, type_entry, op1_value, op2_casted);
-                } else if (type_entry->data.integral.is_signed) {
+                    return gen_overflow_shl_op(g, scalar_type, op1_value, op2_casted);
+                } else if (scalar_type->data.integral.is_signed) {
                     return ZigLLVMBuildNSWShl(g->builder, op1_value, op2_casted, "");
                 } else {
                     return ZigLLVMBuildNUWShl(g->builder, op1_value, op2_casted, "");
@@ -2742,19 +2765,18 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         case IrBinOpBitShiftRightLossy:
         case IrBinOpBitShiftRightExact:
             {
-                assert(type_entry->id == ZigTypeIdInt);
-                LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type,
-                        type_entry, op2_value);
+                assert(scalar_type->id == ZigTypeIdInt);
+                LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value.type, scalar_type, op2_value);
                 bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
                 if (is_sloppy) {
-                    if (type_entry->data.integral.is_signed) {
+                    if (scalar_type->data.integral.is_signed) {
                         return LLVMBuildAShr(g->builder, op1_value, op2_casted, "");
                     } else {
                         return LLVMBuildLShr(g->builder, op1_value, op2_casted, "");
                     }
                 } else if (want_runtime_safety) {
-                    return gen_overflow_shr_op(g, type_entry, op1_value, op2_casted);
-                } else if (type_entry->data.integral.is_signed) {
+                    return gen_overflow_shr_op(g, scalar_type, op1_value, op2_casted);
+                } else if (scalar_type->data.integral.is_signed) {
                     return ZigLLVMBuildAShrExact(g->builder, op1_value, op2_casted, "");
                 } else {
                     return ZigLLVMBuildLShrExact(g->builder, op1_value, op2_casted, "");
@@ -2762,22 +2784,22 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
             }
         case IrBinOpDivUnspecified:
             return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, DivKindFloat);
+                    op1_value, op2_value, scalar_type, DivKindFloat);
         case IrBinOpDivExact:
             return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, DivKindExact);
+                    op1_value, op2_value, scalar_type, DivKindExact);
         case IrBinOpDivTrunc:
             return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, DivKindTrunc);
+                    op1_value, op2_value, scalar_type, DivKindTrunc);
         case IrBinOpDivFloor:
             return gen_div(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, DivKindFloor);
+                    op1_value, op2_value, scalar_type, DivKindFloor);
         case IrBinOpRemRem:
             return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, RemKindRem);
+                    op1_value, op2_value, scalar_type, RemKindRem);
         case IrBinOpRemMod:
             return gen_rem(g, want_runtime_safety, ir_want_fast_math(g, &bin_op_instruction->base),
-                    op1_value, op2_value, type_entry, RemKindMod);
+                    op1_value, op2_value, scalar_type, RemKindMod);
     }
     zig_unreachable();
 }