Commit 4650e5b9fc

mlugg <mlugg@mlugg.co.uk>
2024-09-17 12:00:38
Sema: clean up cmpNumeric
There is one minor language change here, which is that comparisons of the form `comptime_inf < runtime_f32` have their results comptime-known. This is consistent with comparisons against comptime NaN for instance, which are always comptime known. A corresponding behavior test is added. This fixes a bug with int comparison elision which my previous commit somehow triggered. `Sema.compareIntsOnlyPossibleResult` is much cleaner now!
1 parent a5c9221
Changed files (4)
src/Sema.zig
@@ -33814,11 +33814,11 @@ fn cmpNumeric(
     const maybe_lhs_val = try sema.resolveValue(lhs);
     const maybe_rhs_val = try sema.resolveValue(rhs);
 
-    // If the LHS is const, check if there is a guaranteed result which does not depend on ths RHS.
+    // If the LHS is const, check if there is a guaranteed result which does not depend on ths RHS value.
     if (maybe_lhs_val) |lhs_val| {
         // Result based on comparison exceeding type bounds
-        if (!lhs_val.isUndef(zcu) and (lhs_ty.isInt(zcu) or lhs_ty_tag == .comptime_int) and rhs_ty.isInt(zcu)) {
-            if (try sema.compareIntsOnlyPossibleResult(try sema.resolveLazyValue(lhs_val), op, rhs_ty)) |res| {
+        if (!lhs_val.isUndef(zcu) and (lhs_ty_tag == .int or lhs_ty_tag == .comptime_int) and rhs_ty.isInt(zcu)) {
+            if (try sema.compareIntsOnlyPossibleResult(lhs_val, op, rhs_ty)) |res| {
                 return if (res) .bool_true else .bool_false;
             }
         }
@@ -33826,13 +33826,20 @@ fn cmpNumeric(
         if (lhs_val.isNan(zcu)) {
             return if (op == .neq) .bool_true else .bool_false;
         }
+        // Result based on inf comparison to int
+        if (lhs_val.isInf(zcu) and rhs_ty_tag == .int) return switch (op) {
+            .neq => .bool_true,
+            .eq => .bool_false,
+            .gt, .gte => if (lhs_val.isNegativeInf(zcu)) .bool_false else .bool_true,
+            .lt, .lte => if (lhs_val.isNegativeInf(zcu)) .bool_true else .bool_false,
+        };
     }
 
-    // If the RHS is const, check if there is a guaranteed result which does not depend on ths LHS.
+    // If the RHS is const, check if there is a guaranteed result which does not depend on ths LHS value.
     if (maybe_rhs_val) |rhs_val| {
         // Result based on comparison exceeding type bounds
-        if (!rhs_val.isUndef(zcu) and (rhs_ty.isInt(zcu) or rhs_ty_tag == .comptime_int) and lhs_ty.isInt(zcu)) {
-            if (try sema.compareIntsOnlyPossibleResult(try sema.resolveLazyValue(rhs_val), op.reverse(), lhs_ty)) |res| {
+        if (!rhs_val.isUndef(zcu) and (rhs_ty_tag == .int or rhs_ty_tag == .comptime_int) and lhs_ty.isInt(zcu)) {
+            if (try sema.compareIntsOnlyPossibleResult(rhs_val, op.reverse(), lhs_ty)) |res| {
                 return if (res) .bool_true else .bool_false;
             }
         }
@@ -33840,6 +33847,13 @@ fn cmpNumeric(
         if (rhs_val.isNan(zcu)) {
             return if (op == .neq) .bool_true else .bool_false;
         }
+        // Result based on inf comparison to int
+        if (rhs_val.isInf(zcu) and lhs_ty_tag == .int) return switch (op) {
+            .neq => .bool_true,
+            .eq => .bool_false,
+            .gt, .gte => if (rhs_val.isNegativeInf(zcu)) .bool_true else .bool_false,
+            .lt, .lte => if (rhs_val.isNegativeInf(zcu)) .bool_false else .bool_true,
+        };
     }
 
     // Any other comparison depends on both values, so the result is undef if either is undef.
@@ -33889,17 +33903,18 @@ fn cmpNumeric(
         const casted_rhs = try sema.coerce(block, dest_ty, rhs, rhs_src);
         return block.addBinOp(Air.Inst.Tag.fromCmpOp(op, block.float_mode == .optimized), casted_lhs, casted_rhs);
     }
+
     // For mixed unsigned integer sizes, implicit cast both operands to the larger integer.
     // For mixed signed and unsigned integers, implicit cast both operands to a signed
     // integer with + 1 bit.
     // For mixed floats and integers, extract the integer part from the float, cast that to
     // a signed integer with mantissa bits + 1, and if there was any non-integral part of the float,
     // add/subtract 1.
-    const lhs_is_signed = if (try sema.resolveDefinedValue(block, lhs_src, lhs)) |lhs_val|
+    const lhs_is_signed = if (maybe_lhs_val) |lhs_val|
         !(try lhs_val.compareAllWithZeroSema(.gte, pt))
     else
         (lhs_ty.isRuntimeFloat() or lhs_ty.isSignedInt(zcu));
-    const rhs_is_signed = if (try sema.resolveDefinedValue(block, rhs_src, rhs)) |rhs_val|
+    const rhs_is_signed = if (maybe_rhs_val) |rhs_val|
         !(try rhs_val.compareAllWithZeroSema(.gte, pt))
     else
         (rhs_ty.isRuntimeFloat() or rhs_ty.isSignedInt(zcu));
@@ -33908,19 +33923,8 @@ fn cmpNumeric(
     var dest_float_type: ?Type = null;
 
     var lhs_bits: usize = undefined;
-    if (try sema.resolveValueResolveLazy(lhs)) |lhs_val| {
-        if (lhs_val.isUndef(zcu))
-            return pt.undefRef(Type.bool);
-        if (lhs_val.isNan(zcu)) switch (op) {
-            .neq => return .bool_true,
-            else => return .bool_false,
-        };
-        if (lhs_val.isInf(zcu)) switch (op) {
-            .neq => return .bool_true,
-            .eq => return .bool_false,
-            .gt, .gte => return if (lhs_val.isNegativeInf(zcu)) .bool_false else .bool_true,
-            .lt, .lte => return if (lhs_val.isNegativeInf(zcu)) .bool_true else .bool_false,
-        };
+    if (maybe_lhs_val) |unresolved_lhs_val| {
+        const lhs_val = try sema.resolveLazyValue(unresolved_lhs_val);
         if (!rhs_is_signed) {
             switch (lhs_val.orderAgainstZero(zcu)) {
                 .gt => {},
@@ -33966,19 +33970,8 @@ fn cmpNumeric(
     }
 
     var rhs_bits: usize = undefined;
-    if (try sema.resolveValueResolveLazy(rhs)) |rhs_val| {
-        if (rhs_val.isUndef(zcu))
-            return pt.undefRef(Type.bool);
-        if (rhs_val.isNan(zcu)) switch (op) {
-            .neq => return .bool_true,
-            else => return .bool_false,
-        };
-        if (rhs_val.isInf(zcu)) switch (op) {
-            .neq => return .bool_true,
-            .eq => return .bool_false,
-            .gt, .gte => return if (rhs_val.isNegativeInf(zcu)) .bool_true else .bool_false,
-            .lt, .lte => return if (rhs_val.isNegativeInf(zcu)) .bool_false else .bool_true,
-        };
+    if (maybe_rhs_val) |unresolved_rhs_val| {
+        const rhs_val = try sema.resolveLazyValue(unresolved_rhs_val);
         if (!lhs_is_signed) {
             switch (rhs_val.orderAgainstZero(zcu)) {
                 .gt => {},
@@ -34045,90 +34038,49 @@ fn compareIntsOnlyPossibleResult(
     lhs_val: Value,
     op: std.math.CompareOperator,
     rhs_ty: Type,
-) Allocator.Error!?bool {
+) SemaError!?bool {
     const pt = sema.pt;
     const zcu = pt.zcu;
-    const rhs_info = rhs_ty.intInfo(zcu);
-    const vs_zero = lhs_val.orderAgainstZeroSema(pt) catch unreachable;
-    const is_zero = vs_zero == .eq;
-    const is_negative = vs_zero == .lt;
-    const is_positive = vs_zero == .gt;
 
-    // Anything vs. zero-sized type has guaranteed outcome.
-    if (rhs_info.bits == 0) return switch (op) {
-        .eq, .lte, .gte => is_zero,
-        .neq, .lt, .gt => !is_zero,
-    };
+    const min_rhs = try rhs_ty.minInt(pt, rhs_ty);
+    const max_rhs = try rhs_ty.maxInt(pt, rhs_ty);
 
-    // Special case for i1, which can only be 0 or -1.
-    // Zero and positive ints have guaranteed outcome.
-    if (rhs_info.bits == 1 and rhs_info.signedness == .signed) {
-        if (is_positive) return switch (op) {
-            .gt, .gte, .neq => true,
-            .lt, .lte, .eq => false,
-        };
-        if (is_zero) return switch (op) {
-            .gte => true,
-            .lt => false,
-            .gt, .lte, .eq, .neq => null,
-        };
+    if (min_rhs.toIntern() == max_rhs.toIntern()) {
+        // RHS is effectively comptime-known.
+        return try Value.compareHeteroSema(lhs_val, op, min_rhs, pt);
     }
 
-    // Negative vs. unsigned has guaranteed outcome.
-    if (rhs_info.signedness == .unsigned and is_negative) return switch (op) {
-        .eq, .gt, .gte => false,
-        .neq, .lt, .lte => true,
-    };
-
-    const sign_adj = @intFromBool(!is_negative and rhs_info.signedness == .signed);
-    const req_bits = lhs_val.intBitCountTwosComp(zcu) + sign_adj;
-
-    // No sized type can have more than 65535 bits.
-    // The RHS type operand is either a runtime value or sized (but undefined) constant.
-    if (req_bits > 65535) return switch (op) {
-        .lt, .lte => is_negative,
-        .gt, .gte => is_positive,
-        .eq => false,
-        .neq => true,
-    };
-    const fits = req_bits <= rhs_info.bits;
+    const against_min = try lhs_val.orderAdvanced(min_rhs, .sema, zcu, pt.tid);
+    const against_max = try lhs_val.orderAdvanced(max_rhs, .sema, zcu, pt.tid);
 
-    // Oversized int has guaranteed outcome.
     switch (op) {
-        .eq => return if (!fits) false else null,
-        .neq => return if (!fits) true else null,
-        .lt, .lte => if (!fits) return is_negative,
-        .gt, .gte => if (!fits) return !is_negative,
+        .eq => {
+            if (against_min.compare(.lt)) return false;
+            if (against_max.compare(.gt)) return false;
+        },
+        .neq => {
+            if (against_min.compare(.lt)) return true;
+            if (against_max.compare(.gt)) return true;
+        },
+        .lt => {
+            if (against_min.compare(.lt)) return true;
+            if (against_max.compare(.gte)) return false;
+        },
+        .gt => {
+            if (against_max.compare(.gt)) return true;
+            if (against_min.compare(.lte)) return false;
+        },
+        .lte => {
+            if (against_min.compare(.lte)) return true;
+            if (against_max.compare(.gt)) return false;
+        },
+        .gte => {
+            if (against_max.compare(.gte)) return true;
+            if (against_min.compare(.lt)) return false;
+        },
     }
 
-    // For any other comparison, we need to know if the LHS value is
-    // equal to the maximum or minimum possible value of the RHS type.
-    const is_min, const is_max = edge: {
-        if (is_zero and rhs_info.signedness == .unsigned) break :edge .{ true, false };
-
-        if (req_bits != rhs_info.bits) break :edge .{ false, false };
-
-        const ty = try pt.intType(
-            if (is_negative) .signed else .unsigned,
-            @intCast(req_bits),
-        );
-        const pop_count = lhs_val.popCount(ty, zcu);
-
-        if (is_negative) {
-            break :edge .{ pop_count == 1, false };
-        } else {
-            break :edge .{ false, pop_count == req_bits - sign_adj };
-        }
-    };
-
-    assert(fits);
-    return switch (op) {
-        .lt => if (is_max) false else null,
-        .lte => if (is_min) true else null,
-        .gt => if (is_min) false else null,
-        .gte => if (is_max) true else null,
-        .eq, .neq => unreachable,
-    };
+    return null;
 }
 
 /// Asserts that lhs and rhs types are both vectors.
src/Type.zig
@@ -3040,8 +3040,7 @@ pub fn minInt(ty: Type, pt: Zcu.PerThread, dest_ty: Type) !Value {
 pub fn minIntScalar(ty: Type, pt: Zcu.PerThread, dest_ty: Type) !Value {
     const zcu = pt.zcu;
     const info = ty.intInfo(zcu);
-    if (info.signedness == .unsigned) return pt.intValue(dest_ty, 0);
-    if (info.bits == 0) return pt.intValue(dest_ty, -1);
+    if (info.signedness == .unsigned or info.bits == 0) return pt.intValue(dest_ty, 0);
 
     if (std.math.cast(u6, info.bits - 1)) |shift| {
         const n = @as(i64, std.math.minInt(i64)) >> (63 - shift);
@@ -3072,10 +3071,7 @@ pub fn maxIntScalar(ty: Type, pt: Zcu.PerThread, dest_ty: Type) !Value {
     const info = ty.intInfo(pt.zcu);
 
     switch (info.bits) {
-        0 => return switch (info.signedness) {
-            .signed => try pt.intValue(dest_ty, -1),
-            .unsigned => try pt.intValue(dest_ty, 0),
-        },
+        0 => return pt.intValue(dest_ty, 0),
         1 => return switch (info.signedness) {
             .signed => try pt.intValue(dest_ty, 0),
             .unsigned => try pt.intValue(dest_ty, 1),
src/Value.zig
@@ -191,7 +191,7 @@ pub fn toBigIntAdvanced(
     comptime strat: ResolveStrat,
     zcu: *Zcu,
     tid: strat.Tid(),
-) Zcu.CompileError!BigIntConst {
+) Zcu.SemaError!BigIntConst {
     const ip = &zcu.intern_pool;
     return switch (val.toIntern()) {
         .bool_false => BigIntMutable.init(&space.limbs, 0).toConst(),
@@ -1038,7 +1038,7 @@ pub fn orderAgainstZeroInner(
     comptime strat: ResolveStrat,
     zcu: *Zcu,
     tid: strat.Tid(),
-) Zcu.CompileError!std.math.Order {
+) Zcu.SemaError!std.math.Order {
     return switch (lhs.toIntern()) {
         .bool_false => .eq,
         .bool_true => .gt,
test/behavior/math.zig
@@ -1729,3 +1729,65 @@ test "@clz works on both vector and scalar inputs" {
     try std.testing.expectEqual(@as(u6, 31), a);
     try std.testing.expectEqual([_]u6{ 31, 31, 31, 31 }, b);
 }
+
+test "runtime comparison to NaN is comptime-known" {
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.cpu.arch.isArmOrThumb() and builtin.target.floatAbi() == .soft) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21234
+
+    const S = struct {
+        fn doTheTest(comptime F: type, x: F) void {
+            const nan = math.nan(F);
+            if (!(nan != x)) comptime unreachable;
+            if (nan == x) comptime unreachable;
+            if (nan > x) comptime unreachable;
+            if (nan < x) comptime unreachable;
+            if (nan >= x) comptime unreachable;
+            if (nan <= x) comptime unreachable;
+        }
+    };
+
+    S.doTheTest(f16, 123.0);
+    S.doTheTest(f32, 123.0);
+    S.doTheTest(f64, 123.0);
+    S.doTheTest(f128, 123.0);
+    comptime S.doTheTest(f16, 123.0);
+    comptime S.doTheTest(f32, 123.0);
+    comptime S.doTheTest(f64, 123.0);
+    comptime S.doTheTest(f128, 123.0);
+}
+
+test "runtime int comparison to inf is comptime-known" {
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.cpu.arch.isArmOrThumb() and builtin.target.floatAbi() == .soft) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21234
+
+    const S = struct {
+        fn doTheTest(comptime F: type, x: u32) void {
+            const inf = math.inf(F);
+            if (!(inf != x)) comptime unreachable;
+            if (inf == x) comptime unreachable;
+            if (x > inf) comptime unreachable;
+            if (x >= inf) comptime unreachable;
+            if (!(x < inf)) comptime unreachable;
+            if (!(x <= inf)) comptime unreachable;
+        }
+    };
+
+    S.doTheTest(f16, 123);
+    S.doTheTest(f32, 123);
+    S.doTheTest(f64, 123);
+    S.doTheTest(f128, 123);
+    comptime S.doTheTest(f16, 123);
+    comptime S.doTheTest(f32, 123);
+    comptime S.doTheTest(f64, 123);
+    comptime S.doTheTest(f128, 123);
+}