Commit 373e21bb56

Andrew Kelley <andrew@ziglang.org>
2019-02-09 21:23:29
implement vector math safety with ext and trunc
1 parent 0a7bdc0
Changed files (3)
src
test
src/codegen.cpp
@@ -1773,25 +1773,46 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
     }
 }
 
+typedef LLVMValueRef (*BuildBinOpFunc)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef, const char *);
+// These are lookup table using the AddSubMul enum as the lookup.
+// If AddSubMul ever changes, then these tables will be out of
+// date.
+static const BuildBinOpFunc float_op[3] = { LLVMBuildFAdd, LLVMBuildFSub, LLVMBuildFMul };
+static const BuildBinOpFunc wrap_op[3] = { LLVMBuildAdd, LLVMBuildSub, LLVMBuildMul };
+static const BuildBinOpFunc signed_op[3] = { LLVMBuildNSWAdd, LLVMBuildNSWSub, LLVMBuildNSWMul };
+static const BuildBinOpFunc unsigned_op[3] = { LLVMBuildNUWAdd, LLVMBuildNUWSub, LLVMBuildNUWMul };
+
 static LLVMValueRef gen_overflow_op(CodeGen *g, ZigType *operand_type, AddSubMul op,
         LLVMValueRef val1, LLVMValueRef val2)
 {
-    LLVMValueRef fn_val = get_int_overflow_fn(g, operand_type, op);
-    LLVMValueRef params[] = {
-        val1,
-        val2,
-    };
-    LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, "");
-    LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
-
     LLVMValueRef overflow_bit;
+    LLVMValueRef result;
+
     if (operand_type->id == ZigTypeIdVector) {
-        LLVMValueRef overflow_vector = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
-        LLVMTypeRef bigger_int_type_ref = LLVMIntType(operand_type->data.vector.len);
-        LLVMValueRef bitcasted_overflow = LLVMBuildBitCast(g->builder, overflow_vector, bigger_int_type_ref, "");
-        LLVMValueRef zero = LLVMConstNull(bigger_int_type_ref);
+        ZigType *int_type = operand_type->data.vector.elem_type;
+        assert(int_type->id == ZigTypeIdInt);
+        LLVMTypeRef one_more_bit_int = LLVMIntType(int_type->data.integral.bit_count + 1);
+        LLVMTypeRef one_more_bit_int_vector = LLVMVectorType(one_more_bit_int, operand_type->data.vector.len);
+        const auto buildExtFn = int_type->data.integral.is_signed ? LLVMBuildSExt : LLVMBuildZExt;
+        LLVMValueRef extended1 = buildExtFn(g->builder, val1, one_more_bit_int_vector, "");
+        LLVMValueRef extended2 = buildExtFn(g->builder, val2, one_more_bit_int_vector, "");
+        LLVMValueRef extended_result = wrap_op[op](g->builder, extended1, extended2, "");
+        result = LLVMBuildTrunc(g->builder, extended_result, operand_type->type_ref, "");
+
+        LLVMValueRef re_extended_result = buildExtFn(g->builder, result, one_more_bit_int_vector, "");
+        LLVMValueRef overflow_vector = LLVMBuildICmp(g->builder, LLVMIntNE, extended_result, re_extended_result, "");
+        LLVMTypeRef bitcast_int_type = LLVMIntType(operand_type->data.vector.len);
+        LLVMValueRef bitcasted_overflow = LLVMBuildBitCast(g->builder, overflow_vector, bitcast_int_type, "");
+        LLVMValueRef zero = LLVMConstNull(bitcast_int_type);
         overflow_bit = LLVMBuildICmp(g->builder, LLVMIntNE, bitcasted_overflow, zero, "");
     } else {
+        LLVMValueRef fn_val = get_int_overflow_fn(g, operand_type, op);
+        LLVMValueRef params[] = {
+            val1,
+            val2,
+        };
+        LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, "");
+        result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
         overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
     }
 
@@ -2623,8 +2644,6 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
 
 }
 
-typedef LLVMValueRef (*BuildBinOpFunc)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef, const char *);
-
 static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         IrInstructionBinOp *bin_op_instruction)
 {
@@ -2690,14 +2709,6 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         case IrBinOpAddWrap:
         case IrBinOpSub:
         case IrBinOpSubWrap: {
-            // These are lookup table using the AddSubMul enum as the lookup.
-            // If AddSubMul ever changes, then these tables will be out of
-            // date.
-            static const BuildBinOpFunc float_op[3] = { LLVMBuildFAdd, LLVMBuildFSub, LLVMBuildFMul };
-            static const BuildBinOpFunc wrap_op[3] = { LLVMBuildAdd, LLVMBuildSub, LLVMBuildMul };
-            static const BuildBinOpFunc signed_op[3] = { LLVMBuildNSWAdd, LLVMBuildNSWSub, LLVMBuildNSWMul };
-            static const BuildBinOpFunc unsigned_op[3] = { LLVMBuildNUWAdd, LLVMBuildNUWSub, LLVMBuildNUWMul };
-
             bool is_wrapping = (op_id == IrBinOpSubWrap || op_id == IrBinOpAddWrap || op_id == IrBinOpMultWrap);
             AddSubMul add_sub_mul =
                 op_id == IrBinOpAdd || op_id == IrBinOpAddWrap ? AddSubMulAdd :
test/stage1/behavior/math.zig
@@ -1,5 +1,7 @@
 const std = @import("std");
 const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
+const expectEqualSlices = std.testing.expectEqualSlices;
 const maxInt = std.math.maxInt;
 const minInt = std.math.minInt;
 
@@ -498,3 +500,18 @@ test "comptime_int param and return" {
 fn comptimeAdd(comptime a: comptime_int, comptime b: comptime_int) comptime_int {
     return a + b;
 }
+
+test "vector integer addition" {
+    const S = struct {
+        fn doTheTest() void {
+            var a: @Vector(4, i32) = []i32{ 1, 2, 3, 4 };
+            var b: @Vector(4, i32) = []i32{ 5, 6, 7, 8 };
+            var result = a + b;
+            var result_array: [4]i32 = result;
+            const expected = []i32{ 6, 8, 10, 12 };
+            expectEqualSlices(i32, &expected, &result_array);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/runtime_safety.zig
@@ -94,6 +94,20 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
         \\}
     );
 
+    cases.addRuntimeSafety("vector integer addition overflow",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\pub fn main() void {
+        \\    var a: @Vector(4, i32) = []i32{ 1, 2, 2147483643, 4 };
+        \\    var b: @Vector(4, i32) = []i32{ 5, 6, 7, 8 };
+        \\    const x = add(a, b);
+        \\}
+        \\fn add(a: @Vector(4, i32), b: @Vector(4, i32)) @Vector(4, i32) {
+        \\    return a + b;
+        \\}
+    );
+
     cases.addRuntimeSafety("integer subtraction overflow",
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);