Commit ea590ece4b

kkHAIKE <kkhaike@gmail.com>
2022-11-18 21:06:49
Sema: optimize compare comptime float with int
1 parent 04f3067
Changed files (3)
src
test
behavior
src/Sema.zig
@@ -13896,9 +13896,25 @@ fn analyzeArithmetic(
                 // because there is a possible value for which the addition would
                 // overflow (max_int), causing illegal behavior.
                 // For floats: either operand being undef makes the result undef.
+                // If either of the operands are inf, and the other operand is zero,
+                // the result is nan.
+                // If either of the operands are nan, the result is nan.
                 if (maybe_lhs_val) |lhs_val| {
                     if (!lhs_val.isUndef()) {
-                        if (try lhs_val.compareAllWithZeroAdvanced(.eq, sema)) {
+                        if (lhs_val.isNan()) {
+                            return sema.addConstant(resolved_type, lhs_val);
+                        }
+                        if (try lhs_val.compareAllWithZeroAdvanced(.eq, sema)) lz: {
+                            if (maybe_rhs_val) |rhs_val| {
+                                if (rhs_val.isNan()) {
+                                    return sema.addConstant(resolved_type, rhs_val);
+                                }
+                                if (rhs_val.isInf()) {
+                                    return sema.addConstant(resolved_type, try Value.Tag.float_32.create(sema.arena, std.math.nan_f32));
+                                }
+                            } else if (resolved_type.isAnyFloat()) {
+                                break :lz;
+                            }
                             const zero_val = if (is_vector) b: {
                                 break :b try Value.Tag.repeated.create(sema.arena, Value.zero);
                             } else Value.zero;
@@ -13918,7 +13934,17 @@ fn analyzeArithmetic(
                             return sema.addConstUndef(resolved_type);
                         }
                     }
-                    if (try rhs_val.compareAllWithZeroAdvanced(.eq, sema)) {
+                    if (rhs_val.isNan()) {
+                        return sema.addConstant(resolved_type, rhs_val);
+                    }
+                    if (try rhs_val.compareAllWithZeroAdvanced(.eq, sema)) rz: {
+                        if (maybe_lhs_val) |lhs_val| {
+                            if (lhs_val.isInf()) {
+                                return sema.addConstant(resolved_type, try Value.Tag.float_32.create(sema.arena, std.math.nan_f32));
+                            }
+                        } else if (resolved_type.isAnyFloat()) {
+                            break :rz;
+                        }
                         const zero_val = if (is_vector) b: {
                             break :b try Value.Tag.repeated.create(sema.arena, Value.zero);
                         } else Value.zero;
@@ -28175,8 +28201,10 @@ fn cmpNumeric(
             else => return Air.Inst.Ref.bool_false,
         };
         if (lhs_val.isInf()) switch (op) {
-            .gt, .neq => return Air.Inst.Ref.bool_true,
-            .lt, .lte, .eq, .gte => return Air.Inst.Ref.bool_false,
+            .neq => return Air.Inst.Ref.bool_true,
+            .eq => return Air.Inst.Ref.bool_false,
+            .gt, .gte => return if (lhs_val.isNegativeInf()) Air.Inst.Ref.bool_false else Air.Inst.Ref.bool_true,
+            .lt, .lte => return if (lhs_val.isNegativeInf()) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false,
         };
         if (!rhs_is_signed) {
             switch (lhs_val.orderAgainstZero()) {
@@ -28193,14 +28221,17 @@ fn cmpNumeric(
             }
         }
         if (lhs_is_float) {
-            var bigint = try float128IntPartToBigInt(sema.gpa, lhs_val.toFloat(f128));
-            defer bigint.deinit();
             if (lhs_val.floatHasFraction()) {
                 switch (op) {
                     .eq => return Air.Inst.Ref.bool_false,
                     .neq => return Air.Inst.Ref.bool_true,
                     else => {},
                 }
+            }
+
+            var bigint = try float128IntPartToBigInt(sema.gpa, lhs_val.toFloat(f128));
+            defer bigint.deinit();
+            if (lhs_val.floatHasFraction()) {
                 if (lhs_is_signed) {
                     try bigint.addScalar(&bigint, -1);
                 } else {
@@ -28228,8 +28259,10 @@ fn cmpNumeric(
             else => return Air.Inst.Ref.bool_false,
         };
         if (rhs_val.isInf()) switch (op) {
-            .lt, .neq => return Air.Inst.Ref.bool_true,
-            .gt, .lte, .eq, .gte => return Air.Inst.Ref.bool_false,
+            .neq => return Air.Inst.Ref.bool_true,
+            .eq => return Air.Inst.Ref.bool_false,
+            .gt, .gte => return if (rhs_val.isNegativeInf()) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false,
+            .lt, .lte => return if (rhs_val.isNegativeInf()) Air.Inst.Ref.bool_false else Air.Inst.Ref.bool_true,
         };
         if (!lhs_is_signed) {
             switch (rhs_val.orderAgainstZero()) {
@@ -28246,14 +28279,17 @@ fn cmpNumeric(
             }
         }
         if (rhs_is_float) {
-            var bigint = try float128IntPartToBigInt(sema.gpa, rhs_val.toFloat(f128));
-            defer bigint.deinit();
             if (rhs_val.floatHasFraction()) {
                 switch (op) {
                     .eq => return Air.Inst.Ref.bool_false,
                     .neq => return Air.Inst.Ref.bool_true,
                     else => {},
                 }
+            }
+
+            var bigint = try float128IntPartToBigInt(sema.gpa, rhs_val.toFloat(f128));
+            defer bigint.deinit();
+            if (rhs_val.floatHasFraction()) {
                 if (rhs_is_signed) {
                     try bigint.addScalar(&bigint, -1);
                 } else {
src/value.zig
@@ -2081,6 +2081,15 @@ pub const Value = extern union {
         op: std.math.CompareOperator,
         opt_sema: ?*Sema,
     ) Module.CompileError!bool {
+        if (lhs.isInf()) {
+            switch (op) {
+                .neq => return true,
+                .eq => return false,
+                .gt, .gte => return !lhs.isNegativeInf(),
+                .lt, .lte => return lhs.isNegativeInf(),
+            }
+        }
+
         switch (lhs.tag()) {
             .repeated => return lhs.castTag(.repeated).?.data.compareAllWithZeroAdvanced(op, opt_sema),
             .aggregate => {
@@ -2089,11 +2098,11 @@ pub const Value = extern union {
                 }
                 return true;
             },
-            .float_16 => if (std.math.isNan(lhs.castTag(.float_16).?.data)) return op != .neq,
-            .float_32 => if (std.math.isNan(lhs.castTag(.float_32).?.data)) return op != .neq,
-            .float_64 => if (std.math.isNan(lhs.castTag(.float_64).?.data)) return op != .neq,
-            .float_80 => if (std.math.isNan(lhs.castTag(.float_80).?.data)) return op != .neq,
-            .float_128 => if (std.math.isNan(lhs.castTag(.float_128).?.data)) return op != .neq,
+            .float_16 => if (std.math.isNan(lhs.castTag(.float_16).?.data)) return op == .neq,
+            .float_32 => if (std.math.isNan(lhs.castTag(.float_32).?.data)) return op == .neq,
+            .float_64 => if (std.math.isNan(lhs.castTag(.float_64).?.data)) return op == .neq,
+            .float_80 => if (std.math.isNan(lhs.castTag(.float_80).?.data)) return op == .neq,
+            .float_128 => if (std.math.isNan(lhs.castTag(.float_128).?.data)) return op == .neq,
             else => {},
         }
         return (try orderAgainstZeroAdvanced(lhs, opt_sema)).compare(op);
@@ -3817,6 +3826,17 @@ pub const Value = extern union {
         };
     }
 
+    pub fn isNegativeInf(val: Value) bool {
+        return switch (val.tag()) {
+            .float_16 => std.math.isNegativeInf(val.castTag(.float_16).?.data),
+            .float_32 => std.math.isNegativeInf(val.castTag(.float_32).?.data),
+            .float_64 => std.math.isNegativeInf(val.castTag(.float_64).?.data),
+            .float_80 => std.math.isNegativeInf(val.castTag(.float_80).?.data),
+            .float_128 => std.math.isNegativeInf(val.castTag(.float_128).?.data),
+            else => false,
+        };
+    }
+
     pub fn floatRem(lhs: Value, rhs: Value, float_type: Type, arena: Allocator, target: Target) !Value {
         if (float_type.zigTypeTag() == .Vector) {
             const result_data = try arena.alloc(Value, float_type.vectorLen());
test/behavior/bugs/12891.zig
@@ -18,3 +18,69 @@ test "inf" {
     var i: usize = 0;
     try std.testing.expect(f > i);
 }
+test "-inf < 0" {
+    const f = comptime -std.math.inf(f64);
+    var i: usize = 0;
+    try std.testing.expect(f < i);
+}
+test "inf >= 1" {
+    const f = comptime std.math.inf(f64);
+    var i: usize = 1;
+    try std.testing.expect(f >= i);
+}
+test "isNan(nan * 1)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const nan_times_one = comptime std.math.nan(f64) * 1;
+    try std.testing.expect(std.math.isNan(nan_times_one));
+}
+test "runtime isNan(nan * 1)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const nan_times_one = std.math.nan(f64) * 1;
+    try std.testing.expect(std.math.isNan(nan_times_one));
+}
+test "isNan(nan * 0)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const nan_times_zero = comptime std.math.nan(f64) * 0;
+    try std.testing.expect(std.math.isNan(nan_times_zero));
+    const zero_times_nan = 0 * comptime std.math.nan(f64);
+    try std.testing.expect(std.math.isNan(zero_times_nan));
+}
+test "isNan(inf * 0)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const inf_times_zero = comptime std.math.inf(f64) * 0;
+    try std.testing.expect(std.math.isNan(inf_times_zero));
+    const zero_times_inf = 0 * comptime std.math.inf(f64);
+    try std.testing.expect(std.math.isNan(zero_times_inf));
+}
+test "runtime isNan(nan * 0)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const nan_times_zero = std.math.nan(f64) * 0;
+    try std.testing.expect(std.math.isNan(nan_times_zero));
+    const zero_times_nan = 0 * std.math.nan(f64);
+    try std.testing.expect(std.math.isNan(zero_times_nan));
+}
+test "runtime isNan(inf * 0)" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const inf_times_zero = std.math.inf(f64) * 0;
+    try std.testing.expect(std.math.isNan(inf_times_zero));
+    const zero_times_inf = 0 * std.math.inf(f64);
+    try std.testing.expect(std.math.isNan(zero_times_inf));
+}