Commit be0f656c21

Andrew Kelley <andrew@ziglang.org>
2019-04-04 21:44:19
fix `@divFloor` returning incorrect value and add `__modti3`
Closes #2152 See #1290
1 parent 12c4ab3
Changed files (7)
src/codegen.cpp
@@ -2455,8 +2455,8 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
         } else {
             zig_unreachable();
         }
-        LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk");
         LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
+        LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk");
         LLVMBuildCondBr(g->builder, is_zero_bit, div_zero_fail_block, div_zero_ok_block);
 
         LLVMPositionBuilderAtEnd(g->builder, div_zero_fail_block);
@@ -2469,8 +2469,8 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
             BigInt int_min_bi = {0};
             eval_min_max_value_int(g, type_entry, &int_min_bi, false);
             LLVMValueRef int_min_value = bigint_to_llvm_const(get_llvm_type(g, type_entry), &int_min_bi);
-            LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk");
             LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowFail");
+            LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk");
             LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, "");
             LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
             LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
@@ -2574,20 +2574,19 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
                 if (!type_entry->data.integral.is_signed) {
                     return LLVMBuildUDiv(g->builder, val1, val2, "");
                 }
-                // const result = @divTrunc(a, b);
-                // if (result >= 0 or result * b == a)
-                //     return result;
-                // else
-                //     return result - 1;
-
-                LLVMValueRef result = LLVMBuildSDiv(g->builder, val1, val2, "");
-                LLVMValueRef is_pos = LLVMBuildICmp(g->builder, LLVMIntSGE, result, zero, "");
-                LLVMValueRef orig_num = LLVMBuildNSWMul(g->builder, result, val2, "");
-                LLVMValueRef orig_ok = LLVMBuildICmp(g->builder, LLVMIntEQ, orig_num, val1, "");
-                LLVMValueRef ok_bit = LLVMBuildOr(g->builder, orig_ok, is_pos, "");
-                LLVMValueRef one = LLVMConstInt(get_llvm_type(g, type_entry), 1, true);
-                LLVMValueRef result_minus_1 = LLVMBuildNSWSub(g->builder, result, one, "");
-                return LLVMBuildSelect(g->builder, ok_bit, result, result_minus_1, "");
+                // const d = @divTrunc(a, b);
+                // const r = @rem(a, b);
+                // return if (r == 0) d else d - ((a < 0) ^ (b < 0));
+
+                LLVMValueRef div_trunc = LLVMBuildSDiv(g->builder, val1, val2, "");
+                LLVMValueRef rem = LLVMBuildSRem(g->builder, val1, val2, "");
+                LLVMValueRef rem_eq_0 = LLVMBuildICmp(g->builder, LLVMIntEQ, rem, zero, "");
+                LLVMValueRef a_lt_0 = LLVMBuildICmp(g->builder, LLVMIntSLT, val1, zero, "");
+                LLVMValueRef b_lt_0 = LLVMBuildICmp(g->builder, LLVMIntSLT, val2, zero, "");
+                LLVMValueRef a_b_xor = LLVMBuildXor(g->builder, a_lt_0, b_lt_0, "");
+                LLVMValueRef a_b_xor_ext = LLVMBuildZExt(g->builder, a_b_xor, LLVMTypeOf(div_trunc), "");
+                LLVMValueRef d_sub_xor = LLVMBuildSub(g->builder, div_trunc, a_b_xor_ext, "");
+                return LLVMBuildSelect(g->builder, rem_eq_0, div_trunc, d_sub_xor, "");
             }
     }
     zig_unreachable();
std/special/compiler_rt/modti3.zig
@@ -0,0 +1,30 @@
+// Ported from:
+//
+// https://github.com/llvm/llvm-project/blob/2ffb1b0413efa9a24eb3c49e710e36f92e2cb50b/compiler-rt/lib/builtins/modti3.c
+
+const udivmod = @import("udivmod.zig").udivmod;
+const builtin = @import("builtin");
+const compiler_rt = @import("../compiler_rt.zig");
+
+pub extern fn __modti3(a: i128, b: i128) i128 {
+    @setRuntimeSafety(builtin.is_test);
+
+    const s_a = a >> (i128.bit_count - 1); // s = a < 0 ? -1 : 0
+    const s_b = b >> (i128.bit_count - 1); // s = b < 0 ? -1 : 0
+
+    const an = (a ^ s_a) -% s_a; // negate if s == -1
+    const bn = (b ^ s_b) -% s_b; // negate if s == -1
+
+    var r: u128 = undefined;
+    _ = udivmod(u128, @bitCast(u128, an), @bitCast(u128, bn), &r);
+    return (@bitCast(i128, r) ^ s_a) -% s_a; // negate if s == -1
+}
+
+pub extern fn __modti3_windows_x86_64(a: *const i128, b: *const i128) void {
+    @setRuntimeSafety(builtin.is_test);
+    compiler_rt.setXmm0(i128, __modti3(a.*, b.*));
+}
+
+test "import modti3" {
+    _ = @import("modti3_test.zig");
+}
std/special/compiler_rt/modti3_test.zig
@@ -0,0 +1,37 @@
+const __modti3 = @import("modti3.zig").__modti3;
+const testing = @import("std").testing;
+
+fn test__modti3(a: i128, b: i128, expected: i128) void {
+    const x = __modti3(a, b);
+    testing.expect(x == expected);
+}
+
+test "modti3" {
+    test__modti3(0, 1, 0);
+    test__modti3(0, -1, 0);
+    test__modti3(5, 3, 2);
+    test__modti3(5, -3, 2);
+    test__modti3(-5, 3, -2);
+    test__modti3(-5, -3, -2);
+
+    test__modti3(0x8000000000000000, 1, 0x0);
+    test__modti3(0x8000000000000000, -1, 0x0);
+    test__modti3(0x8000000000000000, 2, 0x0);
+    test__modti3(0x8000000000000000, -2, 0x0);
+    test__modti3(0x8000000000000000, 3, 2);
+    test__modti3(0x8000000000000000, -3, 2);
+
+    test__modti3(make_ti(0x8000000000000000, 0), 1, 0x0);
+    test__modti3(make_ti(0x8000000000000000, 0), -1, 0x0);
+    test__modti3(make_ti(0x8000000000000000, 0), 2, 0x0);
+    test__modti3(make_ti(0x8000000000000000, 0), -2, 0x0);
+    test__modti3(make_ti(0x8000000000000000, 0), 3, -2);
+    test__modti3(make_ti(0x8000000000000000, 0), -3, -2);
+}
+
+fn make_ti(high: u64, low: u64) i128 {
+    var result: u128 = high;
+    result <<= 64;
+    result |= low;
+    return @bitCast(i128, result);
+}
std/special/compiler_rt/mulXf3.zig
@@ -17,6 +17,7 @@ pub extern fn __mulsf3(a: f32, b: f32) f32 {
 }
 
 fn mulXf3(comptime T: type, a: T, b: T) T {
+    @setRuntimeSafety(builtin.is_test);
     const Z = @IntType(false, T.bit_count);
 
     const typeWidth = T.bit_count;
@@ -145,6 +146,7 @@ fn mulXf3(comptime T: type, a: T, b: T) T {
 }
 
 fn wideMultiply(comptime Z: type, a: Z, b: Z, hi: *Z, lo: *Z) void {
+    @setRuntimeSafety(builtin.is_test);
     switch (Z) {
         u32 => {
             // 32x32 --> 64 bit multiply
@@ -253,6 +255,7 @@ fn wideMultiply(comptime Z: type, a: Z, b: Z, hi: *Z, lo: *Z) void {
 }
 
 fn normalize(comptime T: type, significand: *@IntType(false, T.bit_count)) i32 {
+    @setRuntimeSafety(builtin.is_test);
     const Z = @IntType(false, T.bit_count);
     const significandBits = std.math.floatMantissaBits(T);
     const implicitBit = Z(1) << significandBits;
@@ -263,6 +266,7 @@ fn normalize(comptime T: type, significand: *@IntType(false, T.bit_count)) i32 {
 }
 
 fn wideRightShiftWithSticky(comptime Z: type, hi: *Z, lo: *Z, count: u32) void {
+    @setRuntimeSafety(builtin.is_test);
     const typeWidth = Z.bit_count;
     const S = std.math.Log2Int(Z);
     if (count < typeWidth) {
std/special/compiler_rt.zig
@@ -159,6 +159,7 @@ comptime {
                     @export("___chkstk_ms", ___chkstk_ms, linkage);
                 }
                 @export("__divti3", @import("compiler_rt/divti3.zig").__divti3_windows_x86_64, linkage);
+                @export("__modti3", @import("compiler_rt/modti3.zig").__modti3_windows_x86_64, linkage);
                 @export("__multi3", @import("compiler_rt/multi3.zig").__multi3_windows_x86_64, linkage);
                 @export("__muloti4", @import("compiler_rt/muloti4.zig").__muloti4_windows_x86_64, linkage);
                 @export("__udivti3", @import("compiler_rt/udivti3.zig").__udivti3_windows_x86_64, linkage);
@@ -169,6 +170,7 @@ comptime {
         }
     } else {
         @export("__divti3", @import("compiler_rt/divti3.zig").__divti3, linkage);
+        @export("__modti3", @import("compiler_rt/modti3.zig").__modti3, linkage);
         @export("__multi3", @import("compiler_rt/multi3.zig").__multi3, linkage);
         @export("__muloti4", @import("compiler_rt/muloti4.zig").__muloti4, linkage);
         @export("__udivti3", @import("compiler_rt/udivti3.zig").__udivti3, linkage);
test/stage1/behavior/math.zig
@@ -31,6 +31,9 @@ fn testDivision() void {
     expect(divFloor(i32, 0, -0x80000000) == 0);
     expect(divFloor(i32, -0x40000001, 0x40000000) == -2);
     expect(divFloor(i32, -0x80000000, 1) == -0x80000000);
+    expect(divFloor(i32, 10, 12) == 0);
+    expect(divFloor(i32, -14, 12) == -2);
+    expect(divFloor(i32, -2, 12) == -1);
 
     expect(divTrunc(i32, 5, 3) == 1);
     expect(divTrunc(i32, -5, 3) == -1);
@@ -40,6 +43,13 @@ fn testDivision() void {
     expect(divTrunc(f32, -5.0, 3.0) == -1.0);
     expect(divTrunc(f64, 5.0, 3.0) == 1.0);
     expect(divTrunc(f64, -5.0, 3.0) == -1.0);
+    expect(divTrunc(i32, 10, 12) == 0);
+    expect(divTrunc(i32, -14, 12) == -1);
+    expect(divTrunc(i32, -2, 12) == 0);
+
+    expect(mod(i32, 10, 12) == 10);
+    expect(mod(i32, -14, 12) == 10);
+    expect(mod(i32, -2, 12) == 10);
 
     comptime {
         expect(
@@ -77,6 +87,9 @@ fn divFloor(comptime T: type, a: T, b: T) T {
 fn divTrunc(comptime T: type, a: T, b: T) T {
     return @divTrunc(a, b);
 }
+fn mod(comptime T: type, a: T, b: T) T {
+    return @mod(a, b);
+}
 
 test "@addWithOverflow" {
     var result: u8 = undefined;
CMakeLists.txt
@@ -662,8 +662,9 @@ set(ZIG_STD_FILES
     "special/compiler_rt/floatuntidf.zig"
     "special/compiler_rt/floatuntisf.zig"
     "special/compiler_rt/floatuntitf.zig"
-    "special/compiler_rt/muloti4.zig"
+    "special/compiler_rt/modti3.zig"
     "special/compiler_rt/mulXf3.zig"
+    "special/compiler_rt/muloti4.zig"
     "special/compiler_rt/multi3.zig"
     "special/compiler_rt/negXf2.zig"
     "special/compiler_rt/popcountdi2.zig"