Commit b5f8fb85e6

Mateusz Radomski <33978857+m-radomski@users.noreply.github.com>
2022-02-13 14:37:38
Implement f128 `@rem`
1 parent f22443b
Changed files (6)
lib
src
test
behavior
lib/std/special/compiler_rt/floatfmodl.zig
@@ -0,0 +1,126 @@
+const builtin = @import("builtin");
+const std = @import("std");
+
+// fmodl - floating modulo large, returns the remainder of division for f128 types
+// Logic and flow heavily inspired by MUSL fmodl for 113 mantissa digits
+pub fn fmodl(a: f128, b: f128) callconv(.C) f128 {
+    @setRuntimeSafety(builtin.is_test);
+    var amod = a;
+    var bmod = b;
+    const aPtr_u64 = @ptrCast([*]u64, &amod);
+    const bPtr_u64 = @ptrCast([*]u64, &bmod);
+    const aPtr_u16 = @ptrCast([*]u16, &amod);
+    const bPtr_u16 = @ptrCast([*]u16, &bmod);
+
+    const exp_and_sign_index = comptime switch (builtin.target.cpu.arch.endian()) {
+        .Little => 7,
+        .Big => 0,
+    };
+    const low_index = comptime switch (builtin.target.cpu.arch.endian()) {
+        .Little => 0,
+        .Big => 1,
+    };
+    const high_index = comptime switch (builtin.target.cpu.arch.endian()) {
+        .Little => 1,
+        .Big => 0,
+    };
+
+    const signA = aPtr_u16[exp_and_sign_index] & 0x8000;
+    var expA = @intCast(i32, (aPtr_u16[exp_and_sign_index] & 0x7fff));
+    var expB = bPtr_u16[exp_and_sign_index] & 0x7fff;
+
+    // There are 3 cases where the answer is undefined, check for:
+    //   - fmodl(val, 0)
+    //   - fmodl(val, NaN)
+    //   - fmodl(inf, val)
+    // The sign on checked values does not matter.
+    // Doing (a * b) / (a * b) procudes undefined results
+    // because the three cases always produce undefined calculations:
+    //   - 0 / 0
+    //   - val * NaN
+    //   - inf / inf
+    if (b == 0 or std.math.isNan(b) or expA == 0x7fff) {
+        return (a * b) / (a * b);
+    }
+
+    // Remove the sign from both
+    aPtr_u16[exp_and_sign_index] = @bitCast(u16, @intCast(i16, expA));
+    bPtr_u16[exp_and_sign_index] = @bitCast(u16, @intCast(i16, expB));
+    if (amod <= bmod) {
+        if (amod == bmod) {
+            return 0 * a;
+        }
+        return a;
+    }
+
+    if (expA == 0) {
+        amod *= 0x1p120;
+        expA = aPtr_u16[exp_and_sign_index] -% 120;
+    }
+
+    if (expB == 0) {
+        bmod *= 0x1p120;
+        expB = bPtr_u16[exp_and_sign_index] -% 120;
+    }
+
+    // OR in extra non-stored mantissa digit
+    var highA: u64 = (aPtr_u64[high_index] & (std.math.maxInt(u64) >> 16)) | 1 << 48;
+    var highB: u64 = (bPtr_u64[high_index] & (std.math.maxInt(u64) >> 16)) | 1 << 48;
+    var lowA: u64 = aPtr_u64[low_index];
+    var lowB: u64 = bPtr_u64[low_index];
+
+    while (expA > expB) : (expA -= 1) {
+        var high = highA -% highB;
+        var low = lowA -% lowB;
+        if (lowA < lowB) {
+            high = highA -% 1;
+        }
+        if (high >> 63 == 0) {
+            if ((high | low) == 0) {
+                return 0 * a;
+            }
+            highA = 2 *% high + (low >> 63);
+            lowA = 2 *% low;
+        } else {
+            highA = 2 *% highA + (lowA >> 63);
+            lowA = 2 *% lowA;
+        }
+    }
+
+    var high = highA -% highB;
+    var low = lowA -% lowB;
+    if (lowA < lowB) {
+        high -= 1;
+    }
+    if (high >> 63 == 0) {
+        if ((high | low) == 0) {
+            return 0 * a;
+        }
+        highA = high;
+        lowA = low;
+    }
+
+    while (highA >> 48 == 0) {
+        highA = 2 *% highA + (lowA >> 63);
+        lowA = 2 *% lowA;
+        expA = expA - 1;
+    }
+
+    // Overwrite the current amod with the values in highA and lowA
+    aPtr_u64[high_index] = highA;
+    aPtr_u64[low_index] = lowA;
+
+    // Combine the exponent with the sign, normalize if happend to be denormalized
+    if (expA <= 0) {
+        aPtr_u16[exp_and_sign_index] = @truncate(u16, @bitCast(u32, (expA +% 120))) | signA;
+        amod *= 0x1p-120;
+    } else {
+        aPtr_u16[exp_and_sign_index] = @truncate(u16, @bitCast(u32, expA)) | signA;
+    }
+
+    return amod;
+}
+
+test {
+    _ = @import("floatfmodl_test.zig");
+}
lib/std/special/compiler_rt/floatfmodl_test.zig
@@ -0,0 +1,46 @@
+const std = @import("std");
+const fmodl = @import("floatfmodl.zig");
+const testing = std.testing;
+
+fn test_fmodl(a: f128, b: f128, exp: f128) !void {
+    const res = fmodl.fmodl(a, b);
+    try testing.expect(exp == res);
+}
+
+fn test_fmodl_nans() !void {
+    try testing.expect(std.math.isNan(fmodl.fmodl(1.0, std.math.nan_f128)));
+    try testing.expect(std.math.isNan(fmodl.fmodl(1.0, -std.math.nan_f128)));
+    try testing.expect(std.math.isNan(fmodl.fmodl(std.math.nan_f128, 1.0)));
+    try testing.expect(std.math.isNan(fmodl.fmodl(-std.math.nan_f128, 1.0)));
+}
+
+fn test_fmodl_infs() !void {
+    try testing.expect(fmodl.fmodl(1.0, std.math.inf_f128) == 1.0);
+    try testing.expect(fmodl.fmodl(1.0, -std.math.inf_f128) == 1.0);
+    try testing.expect(std.math.isNan(fmodl.fmodl(std.math.inf_f128, 1.0)));
+    try testing.expect(std.math.isNan(fmodl.fmodl(-std.math.inf_f128, 1.0)));
+}
+
+test "fmodl" {
+    try test_fmodl(6.8, 4.0, 2.8);
+    try test_fmodl(6.8, -4.0, 2.8);
+    try test_fmodl(-6.8, 4.0, -2.8);
+    try test_fmodl(-6.8, -4.0, -2.8);
+    try test_fmodl(3.0, 2.0, 1.0);
+    try test_fmodl(-5.0, 3.0, -2.0);
+    try test_fmodl(3.0, 2.0, 1.0);
+    try test_fmodl(1.0, 2.0, 1.0);
+    try test_fmodl(0.0, 1.0, 0.0);
+    try test_fmodl(-0.0, 1.0, -0.0);
+    try test_fmodl(7046119.0, 5558362.0, 1487757.0);
+    try test_fmodl(9010357.0, 1957236.0, 1181413.0);
+
+    // Denormals
+    const a: f128 = 0xedcb34a235253948765432134674p-16494;
+    const b: f128 = 0x5d2e38791cfbc0737402da5a9518p-16494;
+    const exp: f128 = 0x336ec3affb2db8618e4e7d5e1c44p-16494;
+    try test_fmodl(a, b, exp);
+
+    try test_fmodl_nans();
+    try test_fmodl_infs();
+}
lib/std/special/compiler_rt.zig
@@ -759,6 +759,9 @@ comptime {
         @export(__unordtf2, .{ .name = "__unordkf2", .linkage = linkage });
     }
 
+    const fmodl = @import("compiler_rt/floatfmodl.zig").fmodl;
+    @export(fmodl, .{ .name = "fmodl", .linkage = linkage });
+
     @export(floorf, .{ .name = "floorf", .linkage = linkage });
     @export(floor, .{ .name = "floor", .linkage = linkage });
     @export(floorl, .{ .name = "floorl", .linkage = linkage });
src/stage1/ir.cpp
@@ -3338,6 +3338,32 @@ static void float_div_floor(ZigValue *out_val, ZigValue *op1, ZigValue *op2) {
     }
 }
 
+// c = a - b * trunc(a / b)
+static float16_t zig_f16_rem(float16_t a, float16_t b) {
+    float16_t c;
+    c = f16_div(a, b);
+    c = f16_roundToInt(c, softfloat_round_minMag, false);
+    c = f16_mul(b, c);
+    c = f16_sub(a, c);
+    return c;
+}
+
+// c = a - b * trunc(a / b)
+static void zig_f128M_rem(const float128_t* a, const float128_t* b, float128_t* c) {
+    f128M_div(a, b, c);
+    f128M_roundToInt(c, softfloat_round_minMag, false, c);
+    f128M_mul(b, c, c);
+    f128M_sub(a, c, c);
+}
+
+// c = a - b * trunc(a / b)
+static void zig_extF80M_rem(const extFloat80_t* a, const extFloat80_t* b, extFloat80_t* c) {
+    extF80M_div(a, b, c);
+    extF80M_roundToInt(c, softfloat_round_minMag, false, c);
+    extF80M_mul(b, c, c);
+    extF80M_sub(a, c, c);
+}
+
 static void float_rem(ZigValue *out_val, ZigValue *op1, ZigValue *op2) {
     assert(op1->type == op2->type);
     out_val->type = op1->type;
@@ -3346,7 +3372,7 @@ static void float_rem(ZigValue *out_val, ZigValue *op1, ZigValue *op2) {
     } else if (op1->type->id == ZigTypeIdFloat) {
         switch (op1->type->data.floating.bit_count) {
             case 16:
-                out_val->data.x_f16 = f16_rem(op1->data.x_f16, op2->data.x_f16);
+                out_val->data.x_f16 = zig_f16_rem(op1->data.x_f16, op2->data.x_f16);
                 return;
             case 32:
                 out_val->data.x_f32 = fmodf(op1->data.x_f32, op2->data.x_f32);
@@ -3355,10 +3381,10 @@ static void float_rem(ZigValue *out_val, ZigValue *op1, ZigValue *op2) {
                 out_val->data.x_f64 = fmod(op1->data.x_f64, op2->data.x_f64);
                 return;
             case 80:
-                extF80M_rem(&op1->data.x_f80, &op2->data.x_f80, &out_val->data.x_f80);
+                zig_extF80M_rem(&op1->data.x_f80, &op2->data.x_f80, &out_val->data.x_f80);
                 return;
             case 128:
-                f128M_rem(&op1->data.x_f128, &op2->data.x_f128, &out_val->data.x_f128);
+                zig_f128M_rem(&op1->data.x_f128, &op2->data.x_f128, &out_val->data.x_f128);
                 return;
             default:
                 zig_unreachable();
src/value.zig
@@ -1482,8 +1482,7 @@ pub const Value = extern union {
             .float_64 => @rem(self.castTag(.float_64).?.data, 1) != 0,
             //.float_80 => @rem(self.castTag(.float_80).?.data, 1) != 0,
             .float_80 => @panic("TODO implement __remx in compiler-rt"),
-            //.float_128 => @rem(self.castTag(.float_128).?.data, 1) != 0,
-            .float_128 => @panic("TODO implement fmodl in compiler-rt"),
+            .float_128 => @rem(self.castTag(.float_128).?.data, 1) != 0,
 
             else => unreachable,
         };
@@ -2888,9 +2887,6 @@ pub const Value = extern union {
                 return Value.Tag.float_80.create(arena, @rem(lhs_val, rhs_val));
             },
             128 => {
-                if (true) {
-                    @panic("TODO implement compiler_rt fmodl");
-                }
                 const lhs_val = lhs.toFloat(f128);
                 const rhs_val = rhs.toFloat(f128);
                 return Value.Tag.float_128.create(arena, @rem(lhs_val, rhs_val));
@@ -2925,9 +2921,6 @@ pub const Value = extern union {
                 return Value.Tag.float_80.create(arena, @mod(lhs_val, rhs_val));
             },
             128 => {
-                if (true) {
-                    @panic("TODO implement compiler_rt fmodl");
-                }
                 const lhs_val = lhs.toFloat(f128);
                 const rhs_val = rhs.toFloat(f128);
                 return Value.Tag.float_128.create(arena, @mod(lhs_val, rhs_val));
test/behavior/math.zig
@@ -782,8 +782,6 @@ test "comptime float rem int" {
 }
 
 test "remainder division" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
     comptime try remdiv(f16);
     comptime try remdiv(f32);
     comptime try remdiv(f64);
@@ -798,6 +796,64 @@ fn remdiv(comptime T: type) !void {
     try expect(@as(T, 1) == @as(T, 7) % @as(T, 3));
 }
 
+test "float remainder division using @rem" {
+    comptime try frem(f16);
+    comptime try frem(f32);
+    comptime try frem(f64);
+    comptime try frem(f128);
+    try frem(f16);
+    try frem(f32);
+    try frem(f64);
+    try frem(f128);
+}
+
+fn frem(comptime T: type) !void {
+    const epsilon = switch (T) {
+        f16 => 1.0,
+        f32 => 0.001,
+        f64 => 0.00001,
+        f128 => 0.0000001,
+        else => unreachable,
+    };
+
+    try expect(std.math.fabs(@rem(@as(T, 6.9), @as(T, 4.0)) - @as(T, 2.9)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, -6.9), @as(T, 4.0)) - @as(T, -2.9)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, -5.0), @as(T, 3.0)) - @as(T, -2.0)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, 3.0), @as(T, 2.0)) - @as(T, 1.0)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, 1.0), @as(T, 2.0)) - @as(T, 1.0)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, 0.0), @as(T, 1.0)) - @as(T, 0.0)) < epsilon);
+    try expect(std.math.fabs(@rem(@as(T, -0.0), @as(T, 1.0)) - @as(T, -0.0)) < epsilon);
+}
+
+test "float modulo division using @mod" {
+    comptime try fmod(f16);
+    comptime try fmod(f32);
+    comptime try fmod(f64);
+    comptime try fmod(f128);
+    try fmod(f16);
+    try fmod(f32);
+    try fmod(f64);
+    try fmod(f128);
+}
+
+fn fmod(comptime T: type) !void {
+    const epsilon = switch (T) {
+        f16 => 1.0,
+        f32 => 0.001,
+        f64 => 0.00001,
+        f128 => 0.0000001,
+        else => unreachable,
+    };
+
+    try expect(std.math.fabs(@mod(@as(T, 6.9), @as(T, 4.0)) - @as(T, 2.9)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, -6.9), @as(T, 4.0)) - @as(T, 1.1)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, -5.0), @as(T, 3.0)) - @as(T, 1.0)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, 3.0), @as(T, 2.0)) - @as(T, 1.0)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, 1.0), @as(T, 2.0)) - @as(T, 1.0)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, 0.0), @as(T, 1.0)) - @as(T, 0.0)) < epsilon);
+    try expect(std.math.fabs(@mod(@as(T, -0.0), @as(T, 1.0)) - @as(T, -0.0)) < epsilon);
+}
+
 test "@sqrt" {
     try testSqrt(f64, 12.0);
     comptime try testSqrt(f64, 12.0);