Commit 569870ca41

Mitchell Hashimoto <mitchell.hashimoto@gmail.com>
2022-03-09 05:44:58
stage2: error_set_merged type equality
This implements type equality for error sets. This is done through element-wise error set comparison. Inferred error sets are always distinct types and other error sets are always sorted. See #11022.
1 parent 0b82c02
src/Module.zig
@@ -824,7 +824,7 @@ pub const ErrorSet = struct {
     /// Offset from Decl node index, points to the error set AST node.
     node_offset: i32,
     /// The string bytes are stored in the owner Decl arena.
-    /// They are in the same order they appear in the AST.
+    /// These must be in sorted order. See sortNames.
     names: NameMap,
 
     pub const NameMap = std.StringArrayHashMapUnmanaged(void);
@@ -836,6 +836,18 @@ pub const ErrorSet = struct {
             .lazy = .{ .node_offset = self.node_offset },
         };
     }
+
+    /// sort the NameMap. This should be called whenever the map is modified.
+    /// alloc should be the allocator used for the NameMap data.
+    pub fn sortNames(names: *NameMap) void {
+        const Context = struct {
+            keys: [][]const u8,
+            pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
+                return std.mem.lessThan(u8, ctx.keys[a_index], ctx.keys[b_index]);
+            }
+        };
+        names.sort(Context{ .keys = names.keys() });
+    }
 };
 
 pub const RequiresComptime = enum { no, yes, unknown, wip };
src/Sema.zig
@@ -2212,6 +2212,10 @@ fn zirErrorSetDecl(
             return sema.fail(block, src, "duplicate error set field {s}", .{name});
         }
     }
+
+    // names must be sorted.
+    Module.ErrorSet.sortNames(&names);
+
     error_set.* = .{
         .owner_decl = new_decl,
         .node_offset = inst_data.src_node,
src/type.zig
@@ -564,27 +564,30 @@ pub const Type = extern union {
             => {
                 if (b.zigTypeTag() != .ErrorSet) return false;
 
-                // TODO: revisit the language specification for how to evaluate equality
-                // for error set types.
-
-                if (a.tag() == .anyerror and b.tag() == .anyerror) {
-                    return true;
+                // inferred error sets are only equal if both are inferred
+                // and they originate from the exact same function.
+                if (a.castTag(.error_set_inferred)) |a_pl| {
+                    if (b.castTag(.error_set_inferred)) |b_pl| {
+                        return a_pl.data.func == b_pl.data.func;
+                    }
+                    return false;
                 }
-
-                if (a.tag() == .error_set and b.tag() == .error_set) {
-                    return a.castTag(.error_set).?.data.owner_decl == b.castTag(.error_set).?.data.owner_decl;
+                if (b.tag() == .error_set_inferred) return false;
+
+                // anyerror matches exactly.
+                const a_is_any = a.isAnyError();
+                const b_is_any = b.isAnyError();
+                if (a_is_any or b_is_any) return a_is_any and b_is_any;
+
+                // two resolved sets match if their error set names match.
+                const a_set = a.errorSetNames();
+                const b_set = b.errorSetNames();
+                if (a_set.len != b_set.len) return false;
+                for (b_set) |b_val| {
+                    if (!a.errorSetHasField(b_val)) return false;
                 }
 
-                if (a.tag() == .error_set_inferred and b.tag() == .error_set_inferred) {
-                    return a.castTag(.error_set_inferred).?.data == b.castTag(.error_set_inferred).?.data;
-                }
-
-                if (a.tag() == .error_set_single and b.tag() == .error_set_single) {
-                    const a_data = a.castTag(.error_set_single).?.data;
-                    const b_data = b.castTag(.error_set_single).?.data;
-                    return std.mem.eql(u8, a_data, b_data);
-                }
-                return false;
+                return true;
             },
 
             .@"opaque" => {
@@ -961,12 +964,30 @@ pub const Type = extern union {
 
             .error_set,
             .error_set_single,
-            .anyerror,
-            .error_set_inferred,
             .error_set_merged,
             => {
+                // all are treated like an "error set" for hashing
+                std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
+                std.hash.autoHash(hasher, Tag.error_set);
+
+                const names = ty.errorSetNames();
+                std.hash.autoHash(hasher, names.len);
+                assert(std.sort.isSorted([]const u8, names, u8, std.mem.lessThan));
+                for (names) |name| hasher.update(name);
+            },
+
+            .anyerror => {
+                // anyerror is distinct from other error sets
                 std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
-                // TODO implement this after revisiting Type.Eql for error sets
+                std.hash.autoHash(hasher, Tag.anyerror);
+            },
+
+            .error_set_inferred => {
+                // inferred error sets are compared using their data pointer
+                const data = ty.castTag(.error_set_inferred).?.data.func;
+                std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
+                std.hash.autoHash(hasher, Tag.error_set_inferred);
+                std.hash.autoHash(hasher, data);
             },
 
             .@"opaque" => {
@@ -4365,6 +4386,9 @@ pub const Type = extern union {
             try names.put(arena, name, {});
         }
 
+        // names must be sorted
+        Module.ErrorSet.sortNames(&names);
+
         return try Tag.error_set_merged.create(arena, names);
     }
 
src/value.zig
@@ -1870,6 +1870,16 @@ pub const Value = extern union {
 
                 return eql(a_payload.container_ptr, b_payload.container_ptr, ty);
             },
+            .@"error" => {
+                const a_name = a.castTag(.@"error").?.data.name;
+                const b_name = b.castTag(.@"error").?.data.name;
+                return std.mem.eql(u8, a_name, b_name);
+            },
+            .eu_payload => {
+                const a_payload = a.castTag(.eu_payload).?.data;
+                const b_payload = b.castTag(.eu_payload).?.data;
+                return eql(a_payload, b_payload, ty.errorUnionPayload());
+            },
             .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
             .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
             .array => {
test/behavior/cast.zig
@@ -669,8 +669,8 @@ test "peer type resolution: disjoint error sets" {
         try expect(error_set_info == .ErrorSet);
         try expect(error_set_info.ErrorSet.?.len == 3);
         try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
     }
 
     {
@@ -678,8 +678,8 @@ test "peer type resolution: disjoint error sets" {
         const error_set_info = @typeInfo(ty);
         try expect(error_set_info == .ErrorSet);
         try expect(error_set_info.ErrorSet.?.len == 3);
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
         try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
     }
 }
@@ -704,8 +704,8 @@ test "peer type resolution: error union and error set" {
 
         const error_set_info = @typeInfo(info.ErrorUnion.error_set);
         try expect(error_set_info.ErrorSet.?.len == 3);
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
         try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
     }
 
@@ -717,8 +717,8 @@ test "peer type resolution: error union and error set" {
         const error_set_info = @typeInfo(info.ErrorUnion.error_set);
         try expect(error_set_info.ErrorSet.?.len == 3);
         try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
-        try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
+        try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
     }
 }
 
test/behavior/error.zig
@@ -330,7 +330,11 @@ fn intLiteral(str: []const u8) !?i64 {
 }
 
 test "nested error union function call in optional unwrap" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const S = struct {
         const Foo = struct {
@@ -375,7 +379,11 @@ test "nested error union function call in optional unwrap" {
 }
 
 test "return function call to error set from error union function" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const S = struct {
         fn errorable() anyerror!i32 {
@@ -404,7 +412,11 @@ test "optional error set is the same size as error set" {
 }
 
 test "nested catch" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const S = struct {
         fn entry() !void {
@@ -428,11 +440,18 @@ test "nested catch" {
 }
 
 test "function pointer with return type that is error union with payload which is pointer of parent struct" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    // This test uses the stage2 const fn pointer
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const S = struct {
         const Foo = struct {
-            fun: fn (a: i32) (anyerror!*Foo),
+            fun: *const fn (a: i32) (anyerror!*Foo),
         };
 
         const Err = error{UnspecifiedErr};
@@ -480,7 +499,11 @@ test "return result loc as peer result loc in inferred error set function" {
 }
 
 test "error payload type is correctly resolved" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const MyIntWrapper = struct {
         const Self = @This();
@@ -496,7 +519,11 @@ test "error payload type is correctly resolved" {
 }
 
 test "error union comptime caching" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
     const S = struct {
         fn quux(comptime arg: anytype) void {
@@ -539,3 +566,69 @@ test "@errorName sentinel length matches slice length" {
 pub fn testBuiltinErrorName(err: anyerror) [:0]const u8 {
     return @errorName(err);
 }
+
+test "error set equality" {
+    // This tests using stage2 logic (#11022)
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+
+    const a = error{One};
+    const b = error{One};
+
+    try expect(a == a);
+    try expect(a == b);
+    try expect(a == error{One});
+
+    // should treat as a set
+    const c = error{ One, Two };
+    const d = error{ Two, One };
+
+    try expect(c == d);
+}
+
+test "inferred error set equality" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+
+    const S = struct {
+        fn foo() !void {
+            return @This().bar();
+        }
+
+        fn bar() !void {
+            return error.Bad;
+        }
+
+        fn baz() !void {
+            return quux();
+        }
+
+        fn quux() anyerror!void {}
+    };
+
+    const FooError = @typeInfo(@typeInfo(@TypeOf(S.foo)).Fn.return_type.?).ErrorUnion.error_set;
+    const BarError = @typeInfo(@typeInfo(@TypeOf(S.bar)).Fn.return_type.?).ErrorUnion.error_set;
+    const BazError = @typeInfo(@typeInfo(@TypeOf(S.baz)).Fn.return_type.?).ErrorUnion.error_set;
+
+    try expect(BarError != error{Bad});
+
+    try expect(FooError != anyerror);
+    try expect(BarError != anyerror);
+    try expect(BazError != anyerror);
+
+    try expect(FooError != BarError);
+    try expect(FooError != BazError);
+    try expect(BarError != BazError);
+
+    try expect(FooError == FooError);
+    try expect(BarError == BarError);
+    try expect(BazError == BazError);
+}
test/behavior/type_info.zig
@@ -205,6 +205,9 @@ test "type info: error set single value" {
 }
 
 test "type info: error set merged" {
+    // #11022 forces ordering of error sets in stage2
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -217,8 +220,8 @@ test "type info: error set merged" {
     try expect(error_set_info == .ErrorSet);
     try expect(error_set_info.ErrorSet.?.len == 3);
     try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
-    try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
-    try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
+    try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
+    try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
 }
 
 test "type info: enum info" {
test/stage2/x86_64.zig
@@ -1412,25 +1412,6 @@ pub fn addCases(ctx: *TestContext) !void {
             });
         }
 
-        {
-            var case = ctx.exe("error set equality", target);
-
-            case.addCompareOutput(
-                \\pub fn main() void {
-                \\    assert(@TypeOf(error.Foo) == @TypeOf(error.Foo));
-                \\    assert(@TypeOf(error.Bar) != @TypeOf(error.Foo));
-                \\    assert(anyerror == anyerror);
-                \\    assert(error{Foo} != error{Foo});
-                \\    // TODO put inferred error sets here when @typeInfo works
-                \\}
-                \\fn assert(b: bool) void {
-                \\    if (!b) unreachable;
-                \\}
-            ,
-                "",
-            );
-        }
-
         {
             var case = ctx.exe("comptime var", target);