Commit 6e78c007df

Andrew Kelley <andrew@ziglang.org>
2021-07-31 02:40:30
Sema: improved AIR when one operand of bool cmp is known
When doing `x == true` or `x == false` it is now lowered as either a no-op or a not, respectively, rather than a cmp instruction. This commit also extracts a zirCmpEq function out from zirCmp, reducing the amount of branching (on is_equality_cmp) in both functions.
1 parent 507dc1f
Changed files (1)
src/Sema.zig
@@ -193,12 +193,12 @@ pub fn analyzeBody(
             .call_compile_time            => try sema.zirCall(block, inst, .compile_time, false),
             .call_nosuspend               => try sema.zirCall(block, inst, .no_async, false),
             .call_async                   => try sema.zirCall(block, inst, .async_kw, false),
-            .cmp_eq                       => try sema.zirCmp(block, inst, .eq),
-            .cmp_gt                       => try sema.zirCmp(block, inst, .gt),
-            .cmp_gte                      => try sema.zirCmp(block, inst, .gte),
             .cmp_lt                       => try sema.zirCmp(block, inst, .lt),
             .cmp_lte                      => try sema.zirCmp(block, inst, .lte),
-            .cmp_neq                      => try sema.zirCmp(block, inst, .neq),
+            .cmp_eq                       => try sema.zirCmpEq(block, inst, .eq, .cmp_eq),
+            .cmp_gte                      => try sema.zirCmp(block, inst, .gte),
+            .cmp_gt                       => try sema.zirCmp(block, inst, .gt),
+            .cmp_neq                      => try sema.zirCmpEq(block, inst, .neq, .cmp_neq),
             .coerce_result_ptr            => try sema.zirCoerceResultPtr(block, inst),
             .decl_ref                     => try sema.zirDeclRef(block, inst),
             .decl_val                     => try sema.zirDeclVal(block, inst),
@@ -5040,17 +5040,18 @@ fn zirAsm(
     return asm_air;
 }
 
-fn zirCmp(
+/// Only called for equality operators. See also `zirCmp`.
+fn zirCmpEq(
     sema: *Sema,
     block: *Scope.Block,
     inst: Zir.Inst.Index,
     op: std.math.CompareOperator,
+    air_tag: Air.Inst.Tag,
 ) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
 
     const mod = sema.mod;
-
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const src: LazySrcLoc = inst_data.src();
@@ -5059,73 +5060,65 @@ fn zirCmp(
     const lhs = sema.resolveInst(extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
 
-    const is_equality_cmp = switch (op) {
-        .eq, .neq => true,
-        else => false,
-    };
     const lhs_ty = sema.typeOf(lhs);
     const rhs_ty = sema.typeOf(rhs);
     const lhs_ty_tag = lhs_ty.zigTypeTag();
     const rhs_ty_tag = rhs_ty.zigTypeTag();
-    if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
+    if (lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
         // null == null, null != null
         if (op == .eq) {
             return Air.Inst.Ref.bool_true;
         } else {
             return Air.Inst.Ref.bool_false;
         }
-    } else if (is_equality_cmp and
-        ((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
+    }
+    if (((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
         rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
     {
         // comparing null with optionals
         const opt_operand = if (lhs_ty_tag == .Optional) lhs else rhs;
         return sema.analyzeIsNull(block, src, opt_operand, op == .neq);
-    } else if (is_equality_cmp and
-        ((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr())))
-    {
+    }
+    if (((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr()))) {
         return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
-    } else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
+    }
+    if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
         const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
         return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
-    } else if (is_equality_cmp and
-        ((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
+    }
+    if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
         (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
     {
         return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
-    } else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
-        if (!is_equality_cmp) {
-            return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
-        }
-        if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| {
-            if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| {
-                if (lval.isUndef() or rval.isUndef()) {
-                    return sema.addConstUndef(Type.initTag(.bool));
-                }
-                // TODO optimisation opportunity: evaluate if mem.eql is faster with the names,
-                // or calling to Module.getErrorValue to get the values and then compare them is
-                // faster.
-                const lhs_name = lval.castTag(.@"error").?.data.name;
-                const rhs_name = rval.castTag(.@"error").?.data.name;
-                if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) {
-                    return Air.Inst.Ref.bool_true;
+    }
+    if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
+        const runtime_src: LazySrcLoc = src: {
+            if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| {
+                if (try sema.resolveMaybeUndefVal(block, rhs_src, rhs)) |rval| {
+                    if (lval.isUndef() or rval.isUndef()) {
+                        return sema.addConstUndef(Type.initTag(.bool));
+                    }
+                    // TODO optimisation opportunity: evaluate if mem.eql is faster with the names,
+                    // or calling to Module.getErrorValue to get the values and then compare them is
+                    // faster.
+                    const lhs_name = lval.castTag(.@"error").?.data.name;
+                    const rhs_name = rval.castTag(.@"error").?.data.name;
+                    if (mem.eql(u8, lhs_name, rhs_name) == (op == .eq)) {
+                        return Air.Inst.Ref.bool_true;
+                    } else {
+                        return Air.Inst.Ref.bool_false;
+                    }
                 } else {
-                    return Air.Inst.Ref.bool_false;
+                    break :src rhs_src;
                 }
+            } else {
+                break :src lhs_src;
             }
-        }
-        try sema.requireRuntimeBlock(block, src);
-        const tag: Air.Inst.Tag = if (op == .eq) .cmp_eq else .cmp_neq;
-        return block.addBinOp(tag, lhs, rhs);
-    } else if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) {
-        // This operation allows any combination of integer and float types, regardless of the
-        // signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for
-        // numeric types.
-        return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src);
-    } else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
-        if (!is_equality_cmp) {
-            return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
-        }
+        };
+        try sema.requireRuntimeBlock(block, runtime_src);
+        return block.addBinOp(air_tag, lhs, rhs);
+    }
+    if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
         const lhs_as_type = try sema.analyzeAsType(block, lhs_src, lhs);
         const rhs_as_type = try sema.analyzeAsType(block, rhs_src, rhs);
         if (lhs_as_type.eql(rhs_as_type) == (op == .eq)) {
@@ -5134,11 +5127,54 @@ fn zirCmp(
             return Air.Inst.Ref.bool_false;
         }
     }
+    return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true);
+}
+
+/// Only called for non-equality operators. See also `zirCmpEq`.
+fn zirCmp(
+    sema: *Sema,
+    block: *Scope.Block,
+    inst: Zir.Inst.Index,
+    op: std.math.CompareOperator,
+) CompileError!Air.Inst.Ref {
+    const tracy = trace(@src());
+    defer tracy.end();
 
+    const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
+    const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
+    const src: LazySrcLoc = inst_data.src();
+    const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
+    const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
+    const lhs = sema.resolveInst(extra.lhs);
+    const rhs = sema.resolveInst(extra.rhs);
+    return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, false);
+}
+
+fn analyzeCmp(
+    sema: *Sema,
+    block: *Scope.Block,
+    src: LazySrcLoc,
+    lhs: Air.Inst.Ref,
+    rhs: Air.Inst.Ref,
+    op: std.math.CompareOperator,
+    lhs_src: LazySrcLoc,
+    rhs_src: LazySrcLoc,
+    is_equality_cmp: bool,
+) CompileError!Air.Inst.Ref {
+    const lhs_ty = sema.typeOf(lhs);
+    const rhs_ty = sema.typeOf(rhs);
+    if (lhs_ty.isNumeric() and rhs_ty.isNumeric()) {
+        // This operation allows any combination of integer and float types, regardless of the
+        // signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for
+        // numeric types.
+        return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src);
+    }
     const instructions = &[_]Air.Inst.Ref{ lhs, rhs };
     const resolved_type = try sema.resolvePeerTypes(block, src, instructions);
     if (!resolved_type.isSelfComparable(is_equality_cmp)) {
-        return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type});
+        return sema.mod.fail(&block.base, src, "{s} operator not allowed for type '{}'", .{
+            @tagName(op), resolved_type,
+        });
     }
 
     const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
@@ -5146,19 +5182,31 @@ fn zirCmp(
 
     const runtime_src: LazySrcLoc = src: {
         if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
+            if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type);
             if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| {
-                if (lhs_val.isUndef() or rhs_val.isUndef()) {
-                    return sema.addConstUndef(resolved_type);
-                }
+                if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type);
+
                 if (lhs_val.compare(op, rhs_val, resolved_type)) {
                     return Air.Inst.Ref.bool_true;
                 } else {
                     return Air.Inst.Ref.bool_false;
                 }
             } else {
+                if (resolved_type.zigTypeTag() == .Bool) {
+                    // We can lower bool eq/neq more efficiently.
+                    return sema.runtimeBoolCmp(block, op, casted_rhs, lhs_val.toBool(), rhs_src);
+                }
                 break :src rhs_src;
             }
         } else {
+            // For bools, we still check the other operand, because we can lower
+            // bool eq/neq more efficiently.
+            if (resolved_type.zigTypeTag() == .Bool) {
+                if (try sema.resolveMaybeUndefVal(block, rhs_src, casted_rhs)) |rhs_val| {
+                    if (rhs_val.isUndef()) return sema.addConstUndef(resolved_type);
+                    return sema.runtimeBoolCmp(block, op, casted_lhs, rhs_val.toBool(), lhs_src);
+                }
+            }
             break :src lhs_src;
         }
     };
@@ -5176,6 +5224,26 @@ fn zirCmp(
     return block.addBinOp(tag, casted_lhs, casted_rhs);
 }
 
+/// cmp_eq (x, false) => not(x)
+/// cmp_eq (x, true ) => x
+/// cmp_neq(x, false) => x
+/// cmp_neq(x, true ) => not(x)
+fn runtimeBoolCmp(
+    sema: *Sema,
+    block: *Scope.Block,
+    op: std.math.CompareOperator,
+    lhs: Air.Inst.Ref,
+    rhs: bool,
+    runtime_src: LazySrcLoc,
+) CompileError!Air.Inst.Ref {
+    if ((op == .neq) == rhs) {
+        try sema.requireRuntimeBlock(block, runtime_src);
+        return block.addTyOp(.not, Type.initTag(.bool), lhs);
+    } else {
+        return lhs;
+    }
+}
+
 fn zirSizeOf(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };