Commit d760cae2b1

Cody Tapscott <topolarity@tapscott.me>
2022-04-19 05:22:43
compiler_rt: implement __mulxf3 for f80
1 parent 5195b87
Changed files (3)
lib
lib/std/special/compiler_rt/mulXf3.zig
@@ -3,12 +3,16 @@
 // https://github.com/llvm/llvm-project/blob/2ffb1b0413efa9a24eb3c49e710e36f92e2cb50b/compiler-rt/lib/builtins/fp_mul_impl.inc
 
 const std = @import("std");
+const math = std.math;
 const builtin = @import("builtin");
 const compiler_rt = @import("../compiler_rt.zig");
 
 pub fn __multf3(a: f128, b: f128) callconv(.C) f128 {
     return mulXf3(f128, a, b);
 }
+pub fn __mulxf3(a: f80, b: f80) callconv(.C) f80 {
+    return mulXf3(f80, a, b);
+}
 pub fn __muldf3(a: f64, b: f64) callconv(.C) f64 {
     return mulXf3(f64, a, b);
 }
@@ -29,30 +33,36 @@ pub fn __aeabi_dmul(a: f64, b: f64) callconv(.C) f64 {
 fn mulXf3(comptime T: type, a: T, b: T) T {
     @setRuntimeSafety(builtin.is_test);
     const typeWidth = @typeInfo(T).Float.bits;
+    const significandBits = math.floatMantissaBits(T);
+    const fractionalBits = math.floatFractionalBits(T);
+    const exponentBits = math.floatExponentBits(T);
+
     const Z = std.meta.Int(.unsigned, typeWidth);
 
-    const significandBits = std.math.floatMantissaBits(T);
-    const exponentBits = std.math.floatExponentBits(T);
+    // ZSignificand is large enough to contain the significand, including an explicit integer bit
+    const ZSignificand = PowerOfTwoSignificandZ(T);
+    const ZSignificandBits = @typeInfo(ZSignificand).Int.bits;
 
+    const roundBit = (1 << (ZSignificandBits - 1));
     const signBit = (@as(Z, 1) << (significandBits + exponentBits));
     const maxExponent = ((1 << exponentBits) - 1);
     const exponentBias = (maxExponent >> 1);
 
-    const implicitBit = (@as(Z, 1) << significandBits);
-    const quietBit = implicitBit >> 1;
-    const significandMask = implicitBit - 1;
+    const integerBit = (@as(ZSignificand, 1) << fractionalBits);
+    const quietBit = integerBit >> 1;
+    const significandMask = (@as(Z, 1) << significandBits) - 1;
 
     const absMask = signBit - 1;
-    const exponentMask = absMask ^ significandMask;
-    const qnanRep = exponentMask | quietBit;
-    const infRep = @bitCast(Z, std.math.inf(T));
+    const qnanRep = @bitCast(Z, math.nan(T)) | quietBit;
+    const infRep = @bitCast(Z, math.inf(T));
+    const minNormalRep = @bitCast(Z, math.floatMin(T));
 
     const aExponent = @truncate(u32, (@bitCast(Z, a) >> significandBits) & maxExponent);
     const bExponent = @truncate(u32, (@bitCast(Z, b) >> significandBits) & maxExponent);
     const productSign: Z = (@bitCast(Z, a) ^ @bitCast(Z, b)) & signBit;
 
-    var aSignificand: Z = @bitCast(Z, a) & significandMask;
-    var bSignificand: Z = @bitCast(Z, b) & significandMask;
+    var aSignificand: ZSignificand = @intCast(ZSignificand, @bitCast(Z, a) & significandMask);
+    var bSignificand: ZSignificand = @intCast(ZSignificand, @bitCast(Z, b) & significandMask);
     var scale: i32 = 0;
 
     // Detect if a or b is zero, denormal, infinity, or NaN.
@@ -93,38 +103,40 @@ fn mulXf3(comptime T: type, a: T, b: T) T {
         // one or both of a or b is denormal, the other (if applicable) is a
         // normal number.  Renormalize one or both of a and b, and set scale to
         // include the necessary exponent adjustment.
-        if (aAbs < implicitBit) scale += normalize(T, &aSignificand);
-        if (bAbs < implicitBit) scale += normalize(T, &bSignificand);
+        if (aAbs < minNormalRep) scale += normalize(T, &aSignificand);
+        if (bAbs < minNormalRep) scale += normalize(T, &bSignificand);
     }
 
     // Or in the implicit significand bit.  (If we fell through from the
     // denormal path it was already set by normalize( ), but setting it twice
     // won't hurt anything.)
-    aSignificand |= implicitBit;
-    bSignificand |= implicitBit;
+    aSignificand |= integerBit;
+    bSignificand |= integerBit;
 
     // Get the significand of a*b.  Before multiplying the significands, shift
     // one of them left to left-align it in the field.  Thus, the product will
     // have (exponentBits + 2) integral digits, all but two of which must be
     // zero.  Normalizing this result is just a conditional left-shift by one
     // and bumping the exponent accordingly.
-    var productHi: Z = undefined;
-    var productLo: Z = undefined;
-    wideMultiply(Z, aSignificand, bSignificand << exponentBits, &productHi, &productLo);
+    var productHi: ZSignificand = undefined;
+    var productLo: ZSignificand = undefined;
+    const left_align_shift = ZSignificandBits - fractionalBits - 1;
+    wideMultiply(ZSignificand, aSignificand, bSignificand << left_align_shift, &productHi, &productLo);
 
-    var productExponent: i32 = @bitCast(i32, aExponent +% bExponent) -% exponentBias +% scale;
+    var productExponent: i32 = @intCast(i32, aExponent + bExponent) - exponentBias + scale;
 
     // Normalize the significand, adjust exponent if needed.
-    if ((productHi & implicitBit) != 0) {
+    if ((productHi & integerBit) != 0) {
         productExponent +%= 1;
     } else {
-        productHi = (productHi << 1) | (productLo >> (typeWidth - 1));
+        productHi = (productHi << 1) | (productLo >> (ZSignificandBits - 1));
         productLo = productLo << 1;
     }
 
     // If we have overflowed the type, return +/- infinity.
     if (productExponent >= maxExponent) return @bitCast(T, infRep | productSign);
 
+    var result: Z = undefined;
     if (productExponent <= 0) {
         // Result is denormal before rounding
         //
@@ -133,35 +145,49 @@ fn mulXf3(comptime T: type, a: T, b: T) T {
         // handle this case separately, but we make it a special case to
         // simplify the shift logic.
         const shift: u32 = @truncate(u32, @as(Z, 1) -% @bitCast(u32, productExponent));
-        if (shift >= typeWidth) return @bitCast(T, productSign);
+        if (shift >= ZSignificandBits) return @bitCast(T, productSign);
 
         // Otherwise, shift the significand of the result so that the round
         // bit is the high bit of productLo.
-        wideRightShiftWithSticky(Z, &productHi, &productLo, shift);
+        const sticky = wideShrWithTruncation(ZSignificand, &productHi, &productLo, shift);
+        productLo |= @boolToInt(sticky);
+        result = productHi;
     } else {
         // Result is normal before rounding; insert the exponent.
-        productHi &= significandMask;
-        productHi |= @as(Z, @bitCast(u32, productExponent)) << significandBits;
+        result = productHi & significandMask;
+        result |= @intCast(Z, productExponent) << significandBits;
     }
 
-    // Insert the sign of the result:
-    productHi |= productSign;
-
     // Final rounding.  The final result may overflow to infinity, or underflow
     // to zero, but those are the correct results in those cases.  We use the
     // default IEEE-754 round-to-nearest, ties-to-even rounding mode.
-    if (productLo > signBit) productHi +%= 1;
-    if (productLo == signBit) productHi +%= productHi & 1;
-    return @bitCast(T, productHi);
+    if (productLo > roundBit) result +%= 1;
+    if (productLo == roundBit) result +%= result & 1;
+
+    // Restore any explicit integer bit, if it was rounded off
+    if (significandBits != fractionalBits) {
+        if ((result >> significandBits) != 0) result |= integerBit;
+    }
+
+    // Insert the sign of the result:
+    result |= productSign;
+
+    return @bitCast(T, result);
 }
 
 fn wideMultiply(comptime Z: type, a: Z, b: Z, hi: *Z, lo: *Z) void {
     @setRuntimeSafety(builtin.is_test);
     switch (Z) {
+        u16 => {
+            // 16x16 --> 32 bit multiply
+            const product = @as(u32, a) * @as(u32, b);
+            hi.* = @intCast(u16, product >> 16);
+            lo.* = @truncate(u16, product);
+        },
         u32 => {
             // 32x32 --> 64 bit multiply
             const product = @as(u64, a) * @as(u64, b);
-            hi.* = @truncate(u32, product >> 32);
+            hi.* = @intCast(u32, product >> 32);
             lo.* = @truncate(u32, product);
         },
         u64 => {
@@ -170,7 +196,7 @@ fn wideMultiply(comptime Z: type, a: Z, b: Z, hi: *Z, lo: *Z) void {
                     return @truncate(u32, x);
                 }
                 fn hiWord(x: u64) u64 {
-                    return @truncate(u32, x >> 32);
+                    return @intCast(u32, x >> 32);
                 }
             };
             // 64x64 -> 128 wide multiply for platforms that don't have such an operation;
@@ -264,34 +290,45 @@ fn wideMultiply(comptime Z: type, a: Z, b: Z, hi: *Z, lo: *Z) void {
     }
 }
 
-fn normalize(comptime T: type, significand: *std.meta.Int(.unsigned, @typeInfo(T).Float.bits)) i32 {
+/// Returns a power-of-two integer type that is large enough to contain
+/// the significand of T, including an explicit integer bit
+fn PowerOfTwoSignificandZ(comptime T: type) type {
+    const bits = math.ceilPowerOfTwoAssert(u16, math.floatFractionalBits(T) + 1);
+    return std.meta.Int(.unsigned, bits);
+}
+
+fn normalize(comptime T: type, significand: *PowerOfTwoSignificandZ(T)) i32 {
     @setRuntimeSafety(builtin.is_test);
-    const Z = std.meta.Int(.unsigned, @typeInfo(T).Float.bits);
-    const significandBits = std.math.floatMantissaBits(T);
-    const implicitBit = @as(Z, 1) << significandBits;
+    const Z = PowerOfTwoSignificandZ(T);
+    const integerBit = @as(Z, 1) << math.floatFractionalBits(T);
 
-    const shift = @clz(Z, significand.*) - @clz(Z, implicitBit);
-    significand.* <<= @intCast(std.math.Log2Int(Z), shift);
+    const shift = @clz(Z, significand.*) - @clz(Z, integerBit);
+    significand.* <<= @intCast(math.Log2Int(Z), shift);
     return @as(i32, 1) - shift;
 }
 
-fn wideRightShiftWithSticky(comptime Z: type, hi: *Z, lo: *Z, count: u32) void {
+// Returns `true` if the right shift is inexact (i.e. any bit shifted out is non-zero)
+//
+// This is analogous to an shr version of `@shlWithOverflow`
+fn wideShrWithTruncation(comptime Z: type, hi: *Z, lo: *Z, count: u32) bool {
     @setRuntimeSafety(builtin.is_test);
     const typeWidth = @typeInfo(Z).Int.bits;
-    const S = std.math.Log2Int(Z);
+    const S = math.Log2Int(Z);
+    var inexact = false;
     if (count < typeWidth) {
-        const sticky = @boolToInt((lo.* << @intCast(S, typeWidth -% count)) != 0);
-        lo.* = (hi.* << @intCast(S, typeWidth -% count)) | (lo.* >> @intCast(S, count)) | sticky;
+        inexact = (lo.* << @intCast(S, typeWidth -% count)) != 0;
+        lo.* = (hi.* << @intCast(S, typeWidth -% count)) | (lo.* >> @intCast(S, count));
         hi.* = hi.* >> @intCast(S, count);
     } else if (count < 2 * typeWidth) {
-        const sticky = @boolToInt((hi.* << @intCast(S, 2 * typeWidth -% count) | lo.*) != 0);
-        lo.* = hi.* >> @intCast(S, count -% typeWidth) | sticky;
+        inexact = (hi.* << @intCast(S, 2 * typeWidth -% count) | lo.*) != 0;
+        lo.* = hi.* >> @intCast(S, count -% typeWidth);
         hi.* = 0;
     } else {
-        const sticky = @boolToInt((hi.* | lo.*) != 0);
-        lo.* = sticky;
+        inexact = (hi.* | lo.*) != 0;
+        lo.* = 0;
         hi.* = 0;
     }
+    return inexact;
 }
 
 test {
lib/std/special/compiler_rt/mulXf3_test.zig
@@ -2,10 +2,15 @@
 //
 // https://github.com/llvm/llvm-project/blob/2ffb1b0413efa9a24eb3c49e710e36f92e2cb50b/compiler-rt/test/builtins/Unit/multf3_test.c
 
+const std = @import("std");
+const math = std.math;
 const qnan128 = @bitCast(f128, @as(u128, 0x7fff800000000000) << 64);
 const inf128 = @bitCast(f128, @as(u128, 0x7fff000000000000) << 64);
 
 const __multf3 = @import("mulXf3.zig").__multf3;
+const __mulxf3 = @import("mulXf3.zig").__mulxf3;
+const __muldf3 = @import("mulXf3.zig").__muldf3;
+const __mulsf3 = @import("mulXf3.zig").__mulsf3;
 
 // return true if equal
 // use two 64-bit integers intead of one 128-bit integer
@@ -97,4 +102,66 @@ test "multf3" {
         0x3f90000000000000,
         0x0,
     );
+
+    try test__multf3(0x1.0000_0000_0000_0000_0000_0000_0001p+0, 0x1.8p+5, 0x4004_8000_0000_0000, 0x0000_0000_0000_0002);
+    try test__multf3(0x1.0000_0000_0000_0000_0000_0000_0002p+0, 0x1.8p+5, 0x4004_8000_0000_0000, 0x0000_0000_0000_0003);
+}
+
+const qnan80 = @bitCast(f80, @bitCast(u80, math.nan(f80)) | (1 << (math.floatFractionalBits(f80) - 1)));
+
+fn test__mulxf3(a: f80, b: f80, expected: u80) !void {
+    const x = __mulxf3(a, b);
+    const rep = @bitCast(u80, x);
+
+    if (rep == expected)
+        return;
+
+    if (math.isNan(@bitCast(f80, expected)) and math.isNan(x))
+        return; // We don't currently test NaN payload propagation
+
+    return error.TestFailed;
+}
+
+test "mulxf3" {
+    // NaN * any = NaN
+    try test__mulxf3(qnan80, 0x1.23456789abcdefp+5, @bitCast(u80, qnan80));
+    try test__mulxf3(@bitCast(f80, @as(u80, 0x7fff_8000_8000_3000_0000)), 0x1.23456789abcdefp+5, @bitCast(u80, qnan80));
+
+    // any * NaN = NaN
+    try test__mulxf3(0x1.23456789abcdefp+5, qnan80, @bitCast(u80, qnan80));
+    try test__mulxf3(0x1.23456789abcdefp+5, @bitCast(f80, @as(u80, 0x7fff_8000_8000_3000_0000)), @bitCast(u80, qnan80));
+
+    // NaN * inf = NaN
+    try test__mulxf3(qnan80, math.inf(f80), @bitCast(u80, qnan80));
+
+    // inf * NaN = NaN
+    try test__mulxf3(math.inf(f80), qnan80, @bitCast(u80, qnan80));
+
+    // inf * inf = inf
+    try test__mulxf3(math.inf(f80), math.inf(f80), @bitCast(u80, math.inf(f80)));
+
+    // inf * -inf = -inf
+    try test__mulxf3(math.inf(f80), -math.inf(f80), @bitCast(u80, -math.inf(f80)));
+
+    // -inf + inf = -inf
+    try test__mulxf3(-math.inf(f80), math.inf(f80), @bitCast(u80, -math.inf(f80)));
+
+    // inf * any = inf
+    try test__mulxf3(math.inf(f80), 0x1.2335653452436234723489432abcdefp+5, @bitCast(u80, math.inf(f80)));
+
+    // any * inf = inf
+    try test__mulxf3(0x1.2335653452436234723489432abcdefp+5, math.inf(f80), @bitCast(u80, math.inf(f80)));
+
+    // any * any
+    try test__mulxf3(0x1.0p+0, 0x1.dcba987654321p+5, 0x4004_ee5d_4c3b_2a19_0800);
+    try test__mulxf3(0x1.0000_0000_0000_0004p+0, 0x1.8p+5, 0x4004_C000_0000_0000_0003); // exact
+
+    try test__mulxf3(0x1.0000_0000_0000_0002p+0, 0x1.0p+5, 0x4004_8000_0000_0000_0001); // exact
+    try test__mulxf3(0x1.0000_0000_0000_0002p+0, 0x1.7ffep+5, 0x4004_BFFF_0000_0000_0001); // round down
+    try test__mulxf3(0x1.0000_0000_0000_0002p+0, 0x1.8p+5, 0x4004_C000_0000_0000_0002); // round up to even
+    try test__mulxf3(0x1.0000_0000_0000_0002p+0, 0x1.8002p+5, 0x4004_C001_0000_0000_0002); // round up
+    try test__mulxf3(0x1.0000_0000_0000_0002p+0, 0x1.0p+6, 0x4005_8000_0000_0000_0001); // exact
+
+    try test__mulxf3(0x1.0000_0001p+0, 0x1.0000_0001p+0, 0x3FFF_8000_0001_0000_0000); // round down to even
+    try test__mulxf3(0x1.0000_0001p+0, 0x1.0000_0001_0002p+0, 0x3FFF_8000_0001_0001_0001); // round up
 }
lib/std/special/compiler_rt.zig
@@ -226,23 +226,26 @@ comptime {
     @export(__addsf3, .{ .name = "__addsf3", .linkage = linkage });
     const __adddf3 = @import("compiler_rt/addXf3.zig").__adddf3;
     @export(__adddf3, .{ .name = "__adddf3", .linkage = linkage });
-    const __addtf3 = @import("compiler_rt/addXf3.zig").__addtf3;
-    @export(__addtf3, .{ .name = "__addtf3", .linkage = linkage });
     const __addxf3 = @import("compiler_rt/addXf3.zig").__addxf3;
     @export(__addxf3, .{ .name = "__addxf3", .linkage = linkage });
+    const __addtf3 = @import("compiler_rt/addXf3.zig").__addtf3;
+    @export(__addtf3, .{ .name = "__addtf3", .linkage = linkage });
+
     const __subsf3 = @import("compiler_rt/addXf3.zig").__subsf3;
     @export(__subsf3, .{ .name = "__subsf3", .linkage = linkage });
     const __subdf3 = @import("compiler_rt/addXf3.zig").__subdf3;
     @export(__subdf3, .{ .name = "__subdf3", .linkage = linkage });
-    const __subtf3 = @import("compiler_rt/addXf3.zig").__subtf3;
-    @export(__subtf3, .{ .name = "__subtf3", .linkage = linkage });
     const __subxf3 = @import("compiler_rt/addXf3.zig").__subxf3;
     @export(__subxf3, .{ .name = "__subxf3", .linkage = linkage });
+    const __subtf3 = @import("compiler_rt/addXf3.zig").__subtf3;
+    @export(__subtf3, .{ .name = "__subtf3", .linkage = linkage });
 
     const __mulsf3 = @import("compiler_rt/mulXf3.zig").__mulsf3;
     @export(__mulsf3, .{ .name = "__mulsf3", .linkage = linkage });
     const __muldf3 = @import("compiler_rt/mulXf3.zig").__muldf3;
     @export(__muldf3, .{ .name = "__muldf3", .linkage = linkage });
+    const __mulxf3 = @import("compiler_rt/mulXf3.zig").__mulxf3;
+    @export(__mulxf3, .{ .name = "__mulxf3", .linkage = linkage });
     const __multf3 = @import("compiler_rt/mulXf3.zig").__multf3;
     @export(__multf3, .{ .name = "__multf3", .linkage = linkage });