Commit 18119aae30

Andrew Kelley <andrew@ziglang.org>
2021-04-07 21:15:05
Sema: implement comparison analysis for non-numeric types
1 parent d9c25ec
Changed files (3)
src
test
stage2
src/Sema.zig
@@ -3776,9 +3776,13 @@ fn zirCmp(
     const tracy = trace(@src());
     defer tracy.end();
 
+    const mod = sema.mod;
+
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const extra = sema.code.extraData(zir.Inst.Bin, inst_data.payload_index).data;
     const src: LazySrcLoc = inst_data.src();
+    const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
+    const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
     const lhs = try sema.resolveInst(extra.lhs);
     const rhs = try sema.resolveInst(extra.rhs);
 
@@ -3790,7 +3794,7 @@ fn zirCmp(
     const rhs_ty_tag = rhs.ty.zigTypeTag();
     if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
         // null == null, null != null
-        return sema.mod.constBool(sema.arena, src, op == .eq);
+        return mod.constBool(sema.arena, src, op == .eq);
     } else if (is_equality_cmp and
         ((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
         rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
@@ -3801,23 +3805,23 @@ fn zirCmp(
     } else if (is_equality_cmp and
         ((lhs_ty_tag == .Null and rhs.ty.isCPtr()) or (rhs_ty_tag == .Null and lhs.ty.isCPtr())))
     {
-        return sema.mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
+        return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
     } else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
         const non_null_type = if (lhs_ty_tag == .Null) rhs.ty else lhs.ty;
-        return sema.mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
+        return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
     } else if (is_equality_cmp and
         ((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
         (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
     {
-        return sema.mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
+        return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
     } else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
         if (!is_equality_cmp) {
-            return sema.mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
+            return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
         }
         if (rhs.value()) |rval| {
             if (lhs.value()) |lval| {
                 // TODO optimisation oppurtunity: evaluate if std.mem.eql is faster with the names, or calling to Module.getErrorValue to get the values and then compare them is faster
-                return sema.mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
+                return mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
             }
         }
         try sema.requireRuntimeBlock(block, src);
@@ -3829,11 +3833,30 @@ fn zirCmp(
         return sema.cmpNumeric(block, src, lhs, rhs, op);
     } else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
         if (!is_equality_cmp) {
-            return sema.mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
+            return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
         }
-        return sema.mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
+        return mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
+    }
+
+    const instructions = &[_]*Inst{ lhs, rhs };
+    const resolved_type = try sema.resolvePeerTypes(block, src, instructions);
+    if (!resolved_type.isSelfComparable(is_equality_cmp)) {
+        return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type});
     }
-    return sema.mod.fail(&block.base, src, "TODO implement more cmp analysis", .{});
+
+    const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
+    const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
+    try sema.requireRuntimeBlock(block, src); // TODO try to do it at comptime
+    const bool_type = Type.initTag(.bool); // TODO handle vectors
+    const tag: Inst.Tag = switch (op) {
+        .lt => .cmp_lt,
+        .lte => .cmp_lte,
+        .eq => .cmp_eq,
+        .gte => .cmp_gte,
+        .gt => .cmp_gt,
+        .neq => .cmp_neq,
+    };
+    return block.addBinOp(src, bool_type, tag, casted_lhs, casted_rhs);
 }
 
 fn zirTypeof(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {
src/type.zig
@@ -107,6 +107,42 @@ pub const Type = extern union {
         }
     }
 
+    pub fn isSelfComparable(ty: Type, is_equality_cmp: bool) bool {
+        return switch (ty.zigTypeTag()) {
+            .Int,
+            .Float,
+            .ComptimeFloat,
+            .ComptimeInt,
+            .Vector, // TODO some vectors require is_equality_cmp==true
+            => true,
+
+            .Bool,
+            .Type,
+            .Void,
+            .ErrorSet,
+            .Fn,
+            .BoundFn,
+            .Opaque,
+            .AnyFrame,
+            .Enum,
+            .EnumLiteral,
+            => is_equality_cmp,
+
+            .NoReturn,
+            .Array,
+            .Struct,
+            .Undefined,
+            .Null,
+            .ErrorUnion,
+            .Union,
+            .Frame,
+            => false,
+
+            .Pointer => is_equality_cmp or ty.isCPtr(),
+            .Optional => is_equality_cmp and ty.isAbiPtr(),
+        };
+    }
+
     pub fn initTag(comptime small_tag: Tag) Type {
         comptime assert(@enumToInt(small_tag) < Tag.no_payload_count);
         return .{ .tag_if_small_enough = @enumToInt(small_tag) };
@@ -1583,6 +1619,11 @@ pub const Type = extern union {
         }
     }
 
+    /// Returns whether the type is represented as a pointer in the ABI.
+    pub fn isAbiPtr(self: Type) bool {
+        @panic("TODO implement this");
+    }
+
     /// Asserts that the type is an error union.
     pub fn errorUnionChild(self: Type) Type {
         return switch (self.tag()) {
test/stage2/cbe.zig
@@ -536,6 +536,27 @@ pub fn addCases(ctx: *TestContext) !void {
         , "");
     }
 
+    {
+        var case = ctx.exeFromCompiledC("enums", .{});
+        case.addCompareOutput(
+            \\const Number = enum { One, Two, Three };
+            \\
+            \\export fn main() c_int {
+            \\    var number1 = Number.One;
+            \\    var number2: Number = .Two;
+            \\    const number3 = @intToEnum(Number, 2);
+            \\    if (number1 == number2) return 1;
+            \\    if (number2 == number3) return 1;
+            \\    if (@enumToInt(number1) != 0) return 1;
+            \\    if (@enumToInt(number2) != 1) return 1;
+            \\    if (@enumToInt(number3) != 2) return 1;
+            \\    var x: Number = .Two;
+            \\    if (number2 != x) return 1;
+            \\    return 0;
+            \\}
+        , "");
+    }
+
     ctx.c("empty start function", linux_x64,
         \\export fn _start() noreturn {
         \\    unreachable;