Commit c5b96c7447

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-10 04:23:44
llvm: fix `@max`/`@min` of unsupported float types
Closes #15611
1 parent 2d2d79a
Changed files (6)
lib/std/math/copysign.zig
@@ -4,16 +4,17 @@ const expect = std.testing.expect;
 
 /// Returns a value with the magnitude of `magnitude` and the sign of `sign`.
 pub fn copysign(magnitude: anytype, sign: @TypeOf(magnitude)) @TypeOf(magnitude) {
-    const T = @TypeOf(magnitude);
-    const TBits = std.meta.Int(.unsigned, @typeInfo(T).Float.bits);
-    const sign_bit_mask = @as(TBits, 1) << (@bitSizeOf(T) - 1);
-    const mag = @bitCast(TBits, magnitude) & ~sign_bit_mask;
-    const sgn = @bitCast(TBits, sign) & sign_bit_mask;
-    return @bitCast(T, mag | sgn);
+    const bits = math.floatBits(@TypeOf(magnitude));
+    const FBits = @Type(.{ .Float = .{ .bits = bits } });
+    const TBits = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = bits } });
+    const sign_bit_mask = @as(TBits, 1) << (bits - 1);
+    const mag = @bitCast(TBits, @as(FBits, magnitude)) & ~sign_bit_mask;
+    const sgn = @bitCast(TBits, @as(FBits, sign)) & sign_bit_mask;
+    return @bitCast(FBits, mag | sgn);
 }
 
 test "math.copysign" {
-    inline for ([_]type{ f16, f32, f64, f80, f128 }) |T| {
+    inline for ([_]type{ f16, f32, f64, f80, f128, c_longdouble, comptime_float }) |T| {
         try expect(copysign(@as(T, 1.0), @as(T, 1.0)) == 1.0);
         try expect(copysign(@as(T, 2.0), @as(T, -2.0)) == -2.0);
         try expect(copysign(@as(T, -3.0), @as(T, 3.0)) == 3.0);
lib/std/math/float.zig
@@ -4,21 +4,29 @@ const expect = std.testing.expect;
 
 /// Creates a raw "1.0" mantissa for floating point type T. Used to dedupe f80 logic.
 inline fn mantissaOne(comptime T: type) comptime_int {
-    return if (@typeInfo(T).Float.bits == 80) 1 << floatFractionalBits(T) else 0;
+    return 1 << floatFractionalBits(T) & ((1 << floatMantissaBits(T)) - 1);
 }
 
 /// Creates floating point type T from an unbiased exponent and raw mantissa.
 inline fn reconstructFloat(comptime T: type, comptime exponent: comptime_int, comptime mantissa: comptime_int) T {
-    const TBits = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = @bitSizeOf(T) } });
+    const FBits = @Type(.{ .Float = .{ .bits = floatBits(T) } });
+    const TBits = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = floatBits(T) } });
     const biased_exponent = @as(TBits, exponent + floatExponentMax(T));
-    return @bitCast(T, (biased_exponent << floatMantissaBits(T)) | @as(TBits, mantissa));
+    return @bitCast(FBits, (biased_exponent << floatMantissaBits(T)) | @as(TBits, mantissa));
+}
+
+/// Returns the number of bits in floating point type T.
+pub inline fn floatBits(comptime T: type) comptime_int {
+    return switch (@typeInfo(T)) {
+        .Float => |info| info.bits,
+        .ComptimeFloat => 128,
+        else => @compileError(@typeName(T) ++ " is not a floating point type"),
+    };
 }
 
 /// Returns the number of bits in the exponent of floating point type T.
 pub inline fn floatExponentBits(comptime T: type) comptime_int {
-    comptime assert(@typeInfo(T) == .Float);
-
-    return switch (@typeInfo(T).Float.bits) {
+    return switch (floatBits(T)) {
         16 => 5,
         32 => 8,
         64 => 11,
@@ -30,9 +38,7 @@ pub inline fn floatExponentBits(comptime T: type) comptime_int {
 
 /// Returns the number of bits in the mantissa of floating point type T.
 pub inline fn floatMantissaBits(comptime T: type) comptime_int {
-    comptime assert(@typeInfo(T) == .Float);
-
-    return switch (@typeInfo(T).Float.bits) {
+    return switch (floatBits(T)) {
         16 => 10,
         32 => 23,
         64 => 52,
@@ -44,12 +50,10 @@ pub inline fn floatMantissaBits(comptime T: type) comptime_int {
 
 /// Returns the number of fractional bits in the mantissa of floating point type T.
 pub inline fn floatFractionalBits(comptime T: type) comptime_int {
-    comptime assert(@typeInfo(T) == .Float);
-
     // standard IEEE floats have an implicit 0.m or 1.m integer part
     // f80 is special and has an explicitly stored bit in the MSB
     // this function corresponds to `MANT_DIG - 1' from C
-    return switch (@typeInfo(T).Float.bits) {
+    return switch (floatBits(T)) {
         16 => 10,
         32 => 23,
         64 => 52,
@@ -101,6 +105,7 @@ test "float bits" {
     inline for ([_]type{ f16, f32, f64, f80, f128, c_longdouble }) |T| {
         // (1 +) for the sign bit, since it is separate from the other bits
         const size = 1 + floatExponentBits(T) + floatMantissaBits(T);
+        try expect(floatBits(T) == size);
         try expect(@bitSizeOf(T) == size);
 
         // for machine epsilon, assert expmin <= -prec <= expmax
lib/std/math/nan.zig
@@ -2,13 +2,13 @@ const math = @import("../math.zig");
 
 /// Returns the nan representation for type T.
 pub inline fn nan(comptime T: type) T {
-    return switch (@typeInfo(T).Float.bits) {
+    return switch (math.floatBits(T)) {
         16 => math.nan_f16,
         32 => math.nan_f32,
         64 => math.nan_f64,
         80 => math.nan_f80,
         128 => math.nan_f128,
-        else => @compileError("unreachable"),
+        else => @compileError("unknown floating point type " ++ @typeName(T)),
     };
 }
 
lib/std/math.zig
@@ -37,6 +37,7 @@ pub const sqrt2 = 1.414213562373095048801688724209698079;
 /// 1/sqrt(2)
 pub const sqrt1_2 = 0.707106781186547524400844362104849039;
 
+pub const floatBits = @import("math/float.zig").floatBits;
 pub const floatExponentBits = @import("math/float.zig").floatExponentBits;
 pub const floatMantissaBits = @import("math/float.zig").floatMantissaBits;
 pub const floatFractionalBits = @import("math/float.zig").floatFractionalBits;
src/codegen/llvm.zig
@@ -7034,7 +7034,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
         const scalar_ty = self.air.typeOfIndex(inst).scalarType();
 
-        if (scalar_ty.isAnyFloat()) return self.builder.buildMinNum(lhs, rhs, "");
+        if (scalar_ty.isAnyFloat()) return self.buildFloatOp(.fmin, scalar_ty, 2, .{ lhs, rhs });
         if (scalar_ty.isSignedInt()) return self.builder.buildSMin(lhs, rhs, "");
         return self.builder.buildUMin(lhs, rhs, "");
     }
@@ -7045,7 +7045,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
         const scalar_ty = self.air.typeOfIndex(inst).scalarType();
 
-        if (scalar_ty.isAnyFloat()) return self.builder.buildMaxNum(lhs, rhs, "");
+        if (scalar_ty.isAnyFloat()) return self.buildFloatOp(.fmax, scalar_ty, 2, .{ lhs, rhs });
         if (scalar_ty.isSignedInt()) return self.builder.buildSMax(lhs, rhs, "");
         return self.builder.buildUMax(lhs, rhs, "");
     }
test/behavior/maximum_minimum.zig
@@ -96,6 +96,31 @@ test "@min for vectors" {
     comptime try S.doTheTest();
 }
 
+test "@min/max for floats" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest(comptime T: type) !void {
+            var x: T = -3.14;
+            var y: T = 5.27;
+            try expectEqual(x, @min(x, y));
+            try expectEqual(x, @min(y, x));
+            try expectEqual(y, @max(x, y));
+            try expectEqual(y, @max(y, x));
+        }
+    };
+
+    inline for (.{ f16, f32, f64, f80, f128, c_longdouble }) |T| {
+        try S.doTheTest(T);
+        comptime try S.doTheTest(T);
+    }
+    comptime try S.doTheTest(comptime_float);
+}
+
 test "@min/@max on lazy values" {
     const A = extern struct { u8_4: [4]u8 };
     const B = extern struct { u8_16: [16]u8 };