master
  1const builtin = @import("builtin");
  2const math = std.math;
  3const std = @import("std");
  4
  5pub const cast = math.cast;
  6pub const fmax = math.floatMax;
  7pub const fmin = math.floatMin;
  8pub const imax = math.maxInt;
  9pub const imin = math.minInt;
 10pub const inf = math.inf;
 11pub const nan = math.nan;
 12pub const next = math.nextAfter;
 13pub const tmin = math.floatTrueMin;
 14
 15pub const Gpr = switch (builtin.cpu.arch) {
 16    else => unreachable,
 17    .x86 => u32,
 18    .x86_64 => u64,
 19};
 20pub const Sse = if (builtin.cpu.has(.x86, .avx))
 21    @Vector(32, u8)
 22else
 23    @Vector(16, u8);
 24
 25pub fn Scalar(comptime Type: type) type {
 26    return switch (@typeInfo(Type)) {
 27        else => Type,
 28        .vector => |info| info.child,
 29    };
 30}
 31pub fn ChangeScalar(comptime Type: type, comptime NewScalar: type) type {
 32    return switch (@typeInfo(Type)) {
 33        else => NewScalar,
 34        .vector => |vector| @Vector(vector.len, NewScalar),
 35    };
 36}
 37pub fn AsSignedness(comptime Type: type, comptime signedness: std.builtin.Signedness) type {
 38    return switch (@typeInfo(Scalar(Type))) {
 39        .int => |int| ChangeScalar(Type, @Int(signedness, int.bits)),
 40        .float => Type,
 41        else => @compileError(@typeName(Type)),
 42    };
 43}
 44pub fn AddOneBit(comptime Type: type) type {
 45    return ChangeScalar(Type, switch (@typeInfo(Scalar(Type))) {
 46        .int => |int| @Int(int.signedness, 1 + int.bits),
 47        .float => Scalar(Type),
 48        else => @compileError(@typeName(Type)),
 49    });
 50}
 51pub fn DoubleBits(comptime Type: type) type {
 52    return ChangeScalar(Type, switch (@typeInfo(Scalar(Type))) {
 53        .int => |int| @Int(int.signedness, int.bits * 2),
 54        .float => Scalar(Type),
 55        else => @compileError(@typeName(Type)),
 56    });
 57}
 58pub fn RoundBitsUp(comptime Type: type, comptime multiple: u16) type {
 59    return ChangeScalar(Type, switch (@typeInfo(Scalar(Type))) {
 60        .int => |int| @Int(int.signedness, std.mem.alignForward(u16, int.bits, multiple)),
 61        .float => Scalar(Type),
 62        else => @compileError(@typeName(Type)),
 63    });
 64}
 65pub fn Log2Int(comptime Type: type) type {
 66    return ChangeScalar(Type, math.Log2Int(Scalar(Type)));
 67}
 68pub fn Log2IntCeil(comptime Type: type) type {
 69    return ChangeScalar(Type, math.Log2IntCeil(Scalar(Type)));
 70}
 71pub fn splat(comptime Type: type, scalar: Scalar(Type)) Type {
 72    return switch (@typeInfo(Type)) {
 73        else => scalar,
 74        .vector => @splat(scalar),
 75    };
 76}
 77pub fn sign(rhs: anytype) ChangeScalar(@TypeOf(rhs), bool) {
 78    const Int = ChangeScalar(@TypeOf(rhs), switch (@typeInfo(Scalar(@TypeOf(rhs)))) {
 79        .int, .comptime_int => Scalar(@TypeOf(rhs)),
 80        .float => |float| @Int(.signed, float.bits),
 81        else => @compileError(@typeName(@TypeOf(rhs))),
 82    });
 83    return @as(Int, @bitCast(rhs)) < splat(Int, 0);
 84}
 85pub fn select(cond: anytype, lhs: anytype, rhs: @TypeOf(lhs)) @TypeOf(lhs) {
 86    return switch (@typeInfo(@TypeOf(cond))) {
 87        .bool => if (cond) lhs else rhs,
 88        .vector => @select(Scalar(@TypeOf(lhs)), cond, lhs, rhs),
 89        else => @compileError(@typeName(@TypeOf(cond))),
 90    };
 91}
 92
 93pub const Compare = enum { strict, relaxed, approx, approx_int, approx_or_overflow };
 94// noinline for a more helpful stack trace
 95pub noinline fn checkExpected(expected: anytype, actual: @TypeOf(expected), comptime compare: Compare) !void {
 96    const Expected = @TypeOf(expected);
 97    const unexpected = unexpected: switch (@typeInfo(Scalar(Expected))) {
 98        else => expected != actual,
 99        .float => switch (compare) {
100            .strict, .relaxed => {
101                const unequal = (expected != actual) & ((expected == expected) | (actual == actual));
102                break :unexpected switch (compare) {
103                    .strict => unequal | (sign(expected) != sign(actual)),
104                    .relaxed => unequal,
105                    .approx, .approx_int, .approx_or_overflow => comptime unreachable,
106                };
107            },
108            .approx, .approx_int, .approx_or_overflow => {
109                const epsilon = math.floatEps(Scalar(Expected));
110                const tolerance = switch (compare) {
111                    .strict, .relaxed => comptime unreachable,
112                    .approx, .approx_int => @sqrt(epsilon),
113                    .approx_or_overflow => @exp2(@log2(epsilon) * 0.4),
114                };
115                const approx_unequal = @abs(expected - actual) > @max(
116                    @abs(expected) * splat(Expected, tolerance),
117                    splat(Expected, switch (compare) {
118                        .strict, .relaxed => comptime unreachable,
119                        .approx, .approx_or_overflow => tolerance,
120                        .approx_int => 1,
121                    }),
122                );
123                break :unexpected switch (compare) {
124                    .strict, .relaxed => comptime unreachable,
125                    .approx, .approx_int => approx_unequal,
126                    .approx_or_overflow => approx_unequal &
127                        (((@abs(expected) != splat(Expected, inf(Expected))) &
128                            (@abs(actual) != splat(Expected, inf(Expected)))) |
129                            (sign(expected) != sign(actual))),
130                };
131            },
132        },
133        .@"struct" => |@"struct"| inline for (@"struct".fields) |field| {
134            try checkExpected(@field(expected, field.name), @field(actual, field.name), compare);
135        } else return,
136    };
137    if (switch (@typeInfo(Expected)) {
138        else => unexpected,
139        .vector => @reduce(.Or, unexpected),
140    }) return error.Unexpected;
141}
test checkExpected {
    // Every float type gets the same `.strict` cases: NaNs of equal sign and zeros of equal
    // sign must match, while a sign mismatch on either must be reported.
    inline for (.{ f16, f32, f64, f80, f128 }) |Float| {
        const pos_nan = nan(Float);
        const neg_nan = -nan(Float);
        const pos_zero: Float = 0.0;
        const neg_zero: Float = -0.0;

        if (checkExpected(pos_nan, pos_nan, .strict) == error.Unexpected) return error.Unexpected;
        if (checkExpected(pos_nan, neg_nan, .strict) != error.Unexpected) return error.Unexpected;
        if (checkExpected(pos_zero, pos_zero, .strict) == error.Unexpected) return error.Unexpected;
        if (checkExpected(neg_zero, neg_zero, .strict) == error.Unexpected) return error.Unexpected;
        if (checkExpected(neg_zero, pos_zero, .strict) != error.Unexpected) return error.Unexpected;
        if (checkExpected(pos_zero, neg_zero, .strict) != error.Unexpected) return error.Unexpected;
    }
}