Commit 18119aae30
src/Sema.zig
@@ -3776,9 +3776,13 @@ fn zirCmp(
const tracy = trace(@src());
defer tracy.end();
+ const mod = sema.mod;
+
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
const extra = sema.code.extraData(zir.Inst.Bin, inst_data.payload_index).data;
const src: LazySrcLoc = inst_data.src();
+ const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
+ const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
const lhs = try sema.resolveInst(extra.lhs);
const rhs = try sema.resolveInst(extra.rhs);
@@ -3790,7 +3794,7 @@ fn zirCmp(
const rhs_ty_tag = rhs.ty.zigTypeTag();
if (is_equality_cmp and lhs_ty_tag == .Null and rhs_ty_tag == .Null) {
// null == null, null != null
- return sema.mod.constBool(sema.arena, src, op == .eq);
+ return mod.constBool(sema.arena, src, op == .eq);
} else if (is_equality_cmp and
((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
@@ -3801,23 +3805,23 @@ fn zirCmp(
} else if (is_equality_cmp and
((lhs_ty_tag == .Null and rhs.ty.isCPtr()) or (rhs_ty_tag == .Null and lhs.ty.isCPtr())))
{
- return sema.mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
+ return mod.fail(&block.base, src, "TODO implement C pointer cmp", .{});
} else if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
const non_null_type = if (lhs_ty_tag == .Null) rhs.ty else lhs.ty;
- return sema.mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
+ return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
} else if (is_equality_cmp and
((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
(rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
{
- return sema.mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
+ return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
} else if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
if (!is_equality_cmp) {
- return sema.mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
+ return mod.fail(&block.base, src, "{s} operator not allowed for errors", .{@tagName(op)});
}
if (rhs.value()) |rval| {
if (lhs.value()) |lval| {
// TODO optimisation oppurtunity: evaluate if std.mem.eql is faster with the names, or calling to Module.getErrorValue to get the values and then compare them is faster
- return sema.mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
+ return mod.constBool(sema.arena, src, std.mem.eql(u8, lval.castTag(.@"error").?.data.name, rval.castTag(.@"error").?.data.name) == (op == .eq));
}
}
try sema.requireRuntimeBlock(block, src);
@@ -3829,11 +3833,30 @@ fn zirCmp(
return sema.cmpNumeric(block, src, lhs, rhs, op);
} else if (lhs_ty_tag == .Type and rhs_ty_tag == .Type) {
if (!is_equality_cmp) {
- return sema.mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
+ return mod.fail(&block.base, src, "{s} operator not allowed for types", .{@tagName(op)});
}
- return sema.mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
+ return mod.constBool(sema.arena, src, lhs.value().?.eql(rhs.value().?) == (op == .eq));
+ }
+
+ const instructions = &[_]*Inst{ lhs, rhs };
+ const resolved_type = try sema.resolvePeerTypes(block, src, instructions);
+ if (!resolved_type.isSelfComparable(is_equality_cmp)) {
+ return mod.fail(&block.base, src, "operator not allowed for type '{}'", .{resolved_type});
}
- return sema.mod.fail(&block.base, src, "TODO implement more cmp analysis", .{});
+
+ const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
+ const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
+ try sema.requireRuntimeBlock(block, src); // TODO try to do it at comptime
+ const bool_type = Type.initTag(.bool); // TODO handle vectors
+ const tag: Inst.Tag = switch (op) {
+ .lt => .cmp_lt,
+ .lte => .cmp_lte,
+ .eq => .cmp_eq,
+ .gte => .cmp_gte,
+ .gt => .cmp_gt,
+ .neq => .cmp_neq,
+ };
+ return block.addBinOp(src, bool_type, tag, casted_lhs, casted_rhs);
}
fn zirTypeof(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {
src/type.zig
@@ -107,6 +107,42 @@ pub const Type = extern union {
}
}
+ pub fn isSelfComparable(ty: Type, is_equality_cmp: bool) bool {
+ return switch (ty.zigTypeTag()) {
+ .Int,
+ .Float,
+ .ComptimeFloat,
+ .ComptimeInt,
+ .Vector, // TODO some vectors require is_equality_cmp==true
+ => true,
+
+ .Bool,
+ .Type,
+ .Void,
+ .ErrorSet,
+ .Fn,
+ .BoundFn,
+ .Opaque,
+ .AnyFrame,
+ .Enum,
+ .EnumLiteral,
+ => is_equality_cmp,
+
+ .NoReturn,
+ .Array,
+ .Struct,
+ .Undefined,
+ .Null,
+ .ErrorUnion,
+ .Union,
+ .Frame,
+ => false,
+
+ .Pointer => is_equality_cmp or ty.isCPtr(),
+ .Optional => is_equality_cmp and ty.isAbiPtr(),
+ };
+ }
+
pub fn initTag(comptime small_tag: Tag) Type {
comptime assert(@enumToInt(small_tag) < Tag.no_payload_count);
return .{ .tag_if_small_enough = @enumToInt(small_tag) };
@@ -1583,6 +1619,11 @@ pub const Type = extern union {
}
}
+ /// Returns whether the type is represented as a pointer in the ABI.
+ pub fn isAbiPtr(self: Type) bool {
+ @panic("TODO implement this");
+ }
+
/// Asserts that the type is an error union.
pub fn errorUnionChild(self: Type) Type {
return switch (self.tag()) {
test/stage2/cbe.zig
@@ -536,6 +536,27 @@ pub fn addCases(ctx: *TestContext) !void {
, "");
}
+ {
+ var case = ctx.exeFromCompiledC("enums", .{});
+ case.addCompareOutput(
+ \\const Number = enum { One, Two, Three };
+ \\
+ \\export fn main() c_int {
+ \\ var number1 = Number.One;
+ \\ var number2: Number = .Two;
+ \\ const number3 = @intToEnum(Number, 2);
+ \\ if (number1 == number2) return 1;
+ \\ if (number2 == number3) return 1;
+ \\ if (@enumToInt(number1) != 0) return 1;
+ \\ if (@enumToInt(number2) != 1) return 1;
+ \\ if (@enumToInt(number3) != 2) return 1;
+ \\ var x: Number = .Two;
+ \\ if (number2 != x) return 1;
+ \\ return 0;
+ \\}
+ , "");
+ }
+
ctx.c("empty start function", linux_x64,
\\export fn _start() noreturn {
\\ unreachable;