Commit e57e835904

Stevie Hryciw <codroid@gmail.com>
2022-11-15 00:56:59
Sema: elide integer comparisons with guaranteed outcomes
1 parent 477038a
src/Sema.zig
@@ -28460,6 +28460,17 @@ fn cmpNumeric(
     const runtime_src: LazySrcLoc = src: {
         if (try sema.resolveMaybeUndefVal(lhs)) |lhs_val| {
             if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
+                // Compare ints: const vs. undefined (or vice versa)
+                if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt() and rhs_val.isUndef()) {
+                    if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                } else if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt() and lhs_val.isUndef()) {
+                    if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
+
                 if (lhs_val.isUndef() or rhs_val.isUndef()) {
                     return sema.addConstUndef(Type.bool);
                 }
@@ -28476,9 +28487,23 @@ fn cmpNumeric(
                     return Air.Inst.Ref.bool_false;
                 }
             } else {
+                if (!lhs_val.isUndef() and (lhs_ty.isInt() or lhs_ty_tag == .ComptimeInt) and rhs_ty.isInt()) {
+                    // Compare ints: const vs. var
+                    if (sema.compareIntsOnlyPossibleResult(target, lhs_val, op, rhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
                 break :src rhs_src;
             }
         } else {
+            if (try sema.resolveMaybeUndefVal(rhs)) |rhs_val| {
+                if (!rhs_val.isUndef() and (rhs_ty.isInt() or rhs_ty_tag == .ComptimeInt) and lhs_ty.isInt()) {
+                    // Compare ints: var vs. const
+                    if (sema.compareIntsOnlyPossibleResult(target, rhs_val, op.reverse(), lhs_ty)) |res| {
+                        return if (res) Air.Inst.Ref.bool_true else Air.Inst.Ref.bool_false;
+                    }
+                }
+            }
             break :src lhs_src;
         }
     };
@@ -28667,6 +28692,107 @@ fn cmpNumeric(
     return block.addBinOp(Air.Inst.Tag.fromCmpOp(op, block.float_mode == .Optimized), casted_lhs, casted_rhs);
 }
 
+/// Asserts that LHS value is an int or comptime int and not undefined, and that RHS type is an int.
+/// Given a const LHS and an unknown RHS, attempt to determine whether `op` has a guaranteed result
+/// for every value the RHS type can hold. `op` is read as `lhs_val op rhs`; a caller whose constant
+/// is the right-hand operand passes `op.reverse()`.
+/// If it cannot be determined, returns null.
+/// Otherwise returns a bool for the guaranteed comparison operation.
+fn compareIntsOnlyPossibleResult(sema: *Sema, target: std.Target, lhs_val: Value, op: std.math.CompareOperator, rhs_ty: Type) ?bool {
+    const rhs_info = rhs_ty.intInfo(target);
+    const vs_zero = lhs_val.orderAgainstZeroAdvanced(sema) catch unreachable;
+    const is_zero = vs_zero == .eq;
+    const is_negative = vs_zero == .lt;
+    const is_positive = vs_zero == .gt;
+
+    // A zero-bit RHS can only hold 0, so the result is simply LHS compared against zero.
+    if (rhs_info.bits == 0) return switch (op) {
+        .eq => is_zero,
+        .neq => !is_zero,
+        .lt => is_negative,
+        .lte => !is_positive,
+        .gt => is_positive,
+        .gte => !is_negative,
+    };
+
+    // Special case for i1, which can only be 0 or -1.
+    // Zero and positive ints have guaranteed outcome.
+    if (rhs_info.bits == 1 and rhs_info.signedness == .signed) {
+        if (is_positive) return switch (op) {
+            .gt, .gte, .neq => true,
+            .lt, .lte, .eq => false,
+        };
+        if (is_zero) return switch (op) {
+            .gte => true,
+            .lt => false,
+            .gt, .lte, .eq, .neq => null,
+        };
+    }
+
+    // Negative vs. unsigned has guaranteed outcome.
+    if (rhs_info.signedness == .unsigned and is_negative) return switch (op) {
+        .eq, .gt, .gte => false,
+        .neq, .lt, .lte => true,
+    };
+
+    // Width LHS requires in an int of RHS's signedness: a non-negative
+    // value is charged one extra (sign) bit when the RHS type is signed.
+    const sign_adj = @boolToInt(!is_negative and rhs_info.signedness == .signed);
+    const req_bits = lhs_val.intBitCountTwosComp(target) + sign_adj;
+
+    // No sized type can have more than 65535 bits.
+    // The RHS type operand is either a runtime value or sized (but undefined) constant.
+    if (req_bits > 65535) return switch (op) {
+        .lt, .lte => is_negative,
+        .gt, .gte => is_positive,
+        .eq => false,
+        .neq => true,
+    };
+    const fits = req_bits <= rhs_info.bits;
+
+    // Oversized int has guaranteed outcome.
+    switch (op) {
+        .eq => return if (!fits) false else null,
+        .neq => return if (!fits) true else null,
+        .lt, .lte => if (!fits) return is_negative,
+        .gt, .gte => if (!fits) return !is_negative,
+    }
+
+    // For any other comparison, we need to know if the LHS value is
+    // equal to the maximum or minimum possible value of the RHS type.
+    const edge: struct { min: bool, max: bool } = edge: {
+        if (is_zero and rhs_info.signedness == .unsigned) break :edge .{ .min = true, .max = false };
+
+        // Needing fewer bits than the type offers puts LHS strictly inside the range.
+        if (req_bits != rhs_info.bits) break :edge .{ .min = false, .max = false };
+
+        var ty_buffer: Type.Payload.Bits = .{
+            .base = .{ .tag = if (is_negative) .int_signed else .int_unsigned },
+            .data = @intCast(u16, req_bits),
+        };
+        const ty = Type.initPayload(&ty_buffer.base);
+        const pop_count = lhs_val.popCount(ty, target);
+
+        // Minimum: only the sign bit is set. Maximum: every non-sign bit is set.
+        if (is_negative) {
+            break :edge .{ .min = pop_count == 1, .max = false };
+        } else {
+            break :edge .{ .min = false, .max = pop_count == req_bits - sign_adj };
+        }
+    };
+
+    // LHS fits in the RHS type, so only an LHS sitting on a range boundary can
+    // decide an ordering comparison; otherwise the result depends on the RHS value.
+    assert(fits);
+    return switch (op) {
+        .lt => if (edge.max) false else null,
+        .lte => if (edge.min) true else null,
+        .gt => if (edge.min) false else null,
+        .gte => if (edge.max) true else null,
+        .eq, .neq => unreachable,
+    };
+}
+
 /// Asserts that lhs and rhs types are both vectors.
 fn cmpVector(
     sema: *Sema,
src/value.zig
@@ -1756,17 +1756,8 @@ pub const Value = extern union {
                 const info = ty.intInfo(target);
 
                 var buffer: Value.BigIntSpace = undefined;
-                const operand_bigint = val.toBigInt(&buffer, target);
-
-                var limbs_buffer: [4]std.math.big.Limb = undefined;
-                var result_bigint = BigIntMutable{
-                    .limbs = &limbs_buffer,
-                    .positive = undefined,
-                    .len = undefined,
-                };
-                result_bigint.popCount(operand_bigint, info.bits);
-
-                return result_bigint.toConst().to(u64) catch unreachable;
+                const int = val.toBigInt(&buffer, target);
+                return @intCast(u64, int.popCount(info.bits));
             },
         }
     }
test/behavior/int_comparison_elision.zig
@@ -0,0 +1,108 @@
+const std = @import("std");
+const minInt = std.math.minInt;
+const maxInt = std.math.maxInt;
+const builtin = @import("builtin");
+
+test "int comparison elision" {
+    // Narrow widths come first; they are checked before any backend is skipped.
+    inline for (.{ u0, i0, u1, i1, u4, i4 }) |T| testIntEdges(T);
+
+    // TODO: support int types > 128 bits wide in other backends
+    switch (builtin.zig_backend) {
+        .stage2_wasm,
+        .stage2_c,
+        .stage2_x86_64,
+        .stage2_aarch64,
+        .stage2_arm,
+        => return error.SkipZigTest, // TODO
+        else => {},
+    }
+
+    // TODO: panic: integer overflow with int types > 65528 bits wide
+    // TODO: LLVM generates too many parameters for wasmtime when splitting up int > 64000 bits wide
+    testIntEdges(u64000);
+    testIntEdges(i64000);
+}
+
+// All comparisons in this test have a guaranteed result for every value of T,
+// so one branch of each 'if' should never be analyzed (it would hit @compileError).
+fn testIntEdges(comptime T: type) void {
+    const min = minInt(T);
+    const max = maxInt(T);
+
+    var runtime_val: T = undefined; // 'var': intended as the runtime operand (const vs. var)
+    // Constant equals T's minimum: no value of type T is below it.
+    if (min > runtime_val) @compileError("analyzed impossible branch");
+    if (min <= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < min) @compileError("analyzed impossible branch");
+    if (runtime_val >= min) {} else @compileError("analyzed impossible branch");
+    // Constant is min - 1, outside T's range: all six operators are decided, in both operand orders.
+    if (min - 1 > runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 >= runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 < runtime_val) {} else @compileError("analyzed impossible branch");
+    if (min - 1 <= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (min - 1 == runtime_val) @compileError("analyzed impossible branch");
+    if (min - 1 != runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val <= min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val > min - 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val >= min - 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val == min - 1) @compileError("analyzed impossible branch");
+    if (runtime_val != min - 1) {} else @compileError("analyzed impossible branch");
+    // Constant equals T's maximum: no value of type T is above it.
+    if (max >= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max < runtime_val) @compileError("analyzed impossible branch");
+    if (runtime_val <= max) {} else @compileError("analyzed impossible branch");
+    if (runtime_val > max) @compileError("analyzed impossible branch");
+    // Constant is max + 1, outside T's range: all six operators are decided, in both operand orders.
+    if (max + 1 > runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max + 1 >= runtime_val) {} else @compileError("analyzed impossible branch");
+    if (max + 1 < runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 <= runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 == runtime_val) @compileError("analyzed impossible branch");
+    if (max + 1 != runtime_val) {} else @compileError("analyzed impossible branch");
+    if (runtime_val < max + 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val <= max + 1) {} else @compileError("analyzed impossible branch");
+    if (runtime_val > max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val >= max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val == max + 1) @compileError("analyzed impossible branch");
+    if (runtime_val != max + 1) {} else @compileError("analyzed impossible branch");
+
+    const undef_const: T = undefined; // comptime-known but undefined operand (const vs. undefined)
+    // Same four groups again, now against the undefined constant. First: T's minimum.
+    if (min > undef_const) @compileError("analyzed impossible branch");
+    if (min <= undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < min) @compileError("analyzed impossible branch");
+    if (undef_const >= min) {} else @compileError("analyzed impossible branch");
+    // min - 1, outside T's range.
+    if (min - 1 > undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 >= undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 < undef_const) {} else @compileError("analyzed impossible branch");
+    if (min - 1 <= undef_const) {} else @compileError("analyzed impossible branch");
+    if (min - 1 == undef_const) @compileError("analyzed impossible branch");
+    if (min - 1 != undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < min - 1) @compileError("analyzed impossible branch");
+    if (undef_const <= min - 1) @compileError("analyzed impossible branch");
+    if (undef_const > min - 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const >= min - 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const == min - 1) @compileError("analyzed impossible branch");
+    if (undef_const != min - 1) {} else @compileError("analyzed impossible branch");
+    // T's maximum.
+    if (max >= undef_const) {} else @compileError("analyzed impossible branch");
+    if (max < undef_const) @compileError("analyzed impossible branch");
+    if (undef_const <= max) {} else @compileError("analyzed impossible branch");
+    if (undef_const > max) @compileError("analyzed impossible branch");
+    // max + 1, outside T's range.
+    if (max + 1 > undef_const) {} else @compileError("analyzed impossible branch");
+    if (max + 1 >= undef_const) {} else @compileError("analyzed impossible branch");
+    if (max + 1 < undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 <= undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 == undef_const) @compileError("analyzed impossible branch");
+    if (max + 1 != undef_const) {} else @compileError("analyzed impossible branch");
+    if (undef_const < max + 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const <= max + 1) {} else @compileError("analyzed impossible branch");
+    if (undef_const > max + 1) @compileError("analyzed impossible branch");
+    if (undef_const >= max + 1) @compileError("analyzed impossible branch");
+    if (undef_const == max + 1) @compileError("analyzed impossible branch");
+    if (undef_const != max + 1) {} else @compileError("analyzed impossible branch");
+}
test/behavior.zig
@@ -158,7 +158,7 @@ test {
     _ = @import("behavior/incomplete_struct_param_tld.zig");
     _ = @import("behavior/inline_switch.zig");
     _ = @import("behavior/int128.zig");
-    _ = @import("behavior/int_div.zig");
+    _ = @import("behavior/int_comparison_elision.zig");
     _ = @import("behavior/inttoptr.zig");
     _ = @import("behavior/ir_block_deps.zig");
     _ = @import("behavior/math.zig");