Commit b2343e63bd

Robin Voetter <robin@voetter.nl>
2021-12-16 02:23:15
stage2: move inferred error set state into func
1 parent cd733ce
Changed files (4)
src/codegen/c.zig
@@ -722,7 +722,7 @@ pub const DeclGen = struct {
         try bw.writeAll(" payload; uint16_t error; } ");
         const name_index = buffer.items.len;
         if (err_set_type.castTag(.error_set_inferred)) |inf_err_set_payload| {
-            const func = inf_err_set_payload.data.func;
+            const func = inf_err_set_payload.data;
             try bw.writeAll("zig_E_");
             try dg.renderDeclName(func.owner_decl, bw);
             try bw.writeAll(";\n");
src/Module.zig
@@ -1207,6 +1207,24 @@ pub const Fn = struct {
     is_cold: bool = false,
     is_noinline: bool = false,
 
+    /// These fields are used to keep track of any dependencies related to functions
+    /// that return inferred error sets. It's values are not used when the function
+    /// does not return an inferred error set.
+    inferred_error_set: struct {
+        /// All currently known errors that this function returns. This includes direct additions
+        /// via `return error.Foo;`, and possibly also errors that are returned from any dependent functions.
+        /// When the inferred error set is fully resolved, this map contains all the errors that the function might return.
+        errors: std.StringHashMapUnmanaged(void) = .{},
+
+        /// Other functions with inferred error sets which the inferred error set of this
+        /// function should include.
+        functions: std.AutoHashMapUnmanaged(*Fn, void) = .{},
+
+        /// Whether the function returned anyerror. This is true if either of the dependent functions
+        /// returns anyerror.
+        is_anyerror: bool = false,
+    } = .{},
+
     pub const Analysis = enum {
         queued,
         /// This function intentionally only has ZIR generated because it is marked
@@ -1222,23 +1240,37 @@ pub const Fn = struct {
     };
 
     pub fn deinit(func: *Fn, gpa: Allocator) void {
-        if (func.getInferredErrorSet()) |error_set_data| {
-            error_set_data.map.deinit(gpa);
-            error_set_data.functions.deinit(gpa);
-        }
+        func.inferred_error_set.errors.deinit(gpa);
+        func.inferred_error_set.functions.deinit(gpa);
     }
 
-    pub fn getInferredErrorSet(func: *Fn) ?*Type.Payload.ErrorSetInferred.Data {
-        const ret_ty = func.owner_decl.ty.fnReturnType();
-        if (ret_ty.tag() == .generic_poison) {
-            return null;
-        }
-        if (ret_ty.zigTypeTag() == .ErrorUnion) {
-            if (ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
-                return &payload.data;
-            }
+    pub fn addErrorSet(func: *Fn, gpa: Allocator, err_set_ty: Type) !void {
+        switch (err_set_ty.tag()) {
+            .error_set => {
+                const names = err_set_ty.castTag(.error_set).?.data.names.keys();
+                for (names) |name| {
+                    try func.inferred_error_set.errors.put(gpa, name, {});
+                }
+            },
+            .error_set_single => {
+                const name = err_set_ty.castTag(.error_set_single).?.data;
+                try func.inferred_error_set.errors.put(gpa, name, {});
+            },
+            .error_set_inferred => {
+                const dependent_func = err_set_ty.castTag(.error_set_inferred).?.data;
+                try func.inferred_error_set.functions.put(gpa, dependent_func, {});
+            },
+            .error_set_merged => {
+                const names = err_set_ty.castTag(.error_set_merged).?.data.keys();
+                for (names) |name| {
+                    try func.inferred_error_set.errors.put(gpa, name, {});
+                }
+            },
+            .anyerror => {
+                func.inferred_error_set.is_anyerror = true;
+            },
+            else => unreachable,
         }
-        return null;
     }
 };
 
src/Sema.zig
@@ -5107,12 +5107,7 @@ fn funcCommon(
         const return_type = if (!inferred_error_set or bare_return_type.tag() == .generic_poison)
             bare_return_type
         else blk: {
-            const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{
-                .func = new_func,
-                .map = .{},
-                .functions = .{},
-                .is_anyerror = false,
-            });
+            const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, new_func);
             break :blk try Type.Tag.error_union.create(sema.arena, .{
                 .error_set = error_set_ty,
                 .payload = bare_return_type,
@@ -9209,14 +9204,14 @@ fn analyzeRet(
     // add the error tag to the inferred error set of the in-scope function, so
     // that the coercion below works correctly.
     if (sema.fn_ret_ty.zigTypeTag() == .ErrorUnion) {
-        if (sema.fn_ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
+        if (sema.fn_ret_ty.errorUnionSet().tag() == .error_set_inferred) {
             const op_ty = sema.typeOf(uncasted_operand);
             switch (op_ty.zigTypeTag()) {
                 .ErrorSet => {
-                    try payload.data.addErrorSet(sema.gpa, op_ty);
+                    try sema.func.?.addErrorSet(sema.gpa, op_ty);
                 },
                 .ErrorUnion => {
-                    try payload.data.addErrorSet(sema.gpa, op_ty.errorUnionSet());
+                    try sema.func.?.addErrorSet(sema.gpa, op_ty.errorUnionSet());
                 },
                 else => {},
             }
@@ -12501,10 +12496,10 @@ fn coerceInMemoryAllowedErrorSets(
     // of inferred error sets.
     if (src_ty.castTag(.error_set_inferred)) |src_payload| {
         if (dest_ty.castTag(.error_set_inferred)) |dst_payload| {
-            const src_func = src_payload.data.func;
-            const dst_func = dst_payload.data.func;
+            const src_func = src_payload.data;
+            const dst_func = dst_payload.data;
 
-            if (src_func == dst_func or dst_payload.data.functions.contains(src_func)) {
+            if (src_func == dst_func or dst_func.inferred_error_set.functions.contains(src_func)) {
                 return .ok;
             }
         }
@@ -13899,10 +13894,10 @@ fn wrapErrorUnion(
                 }
             },
             .error_set_inferred => ok: {
-                const err_set_payload = dest_err_set_ty.castTag(.error_set_inferred).?.data;
-                if (err_set_payload.is_anyerror) break :ok;
+                const func = dest_err_set_ty.castTag(.error_set_inferred).?.data;
+                if (func.inferred_error_set.is_anyerror) break :ok;
                 const expected_name = val.castTag(.@"error").?.data.name;
-                if (err_set_payload.map.contains(expected_name)) break :ok;
+                if (func.inferred_error_set.errors.contains(expected_name)) break :ok;
                 // TODO error set resolution here before emitting a compile error
                 return sema.failWithErrorSetCodeMissing(block, inst_src, dest_err_set_ty, inst_ty);
             },
src/type.zig
@@ -627,7 +627,7 @@ pub const Type = extern union {
                 }
 
                 if (a.tag() == .error_set_inferred and b.tag() == .error_set_inferred) {
-                    return a.castTag(.error_set_inferred).?.data.func == b.castTag(.error_set_inferred).?.data.func;
+                    return a.castTag(.error_set_inferred).?.data == b.castTag(.error_set_inferred).?.data;
                 }
 
                 if (a.tag() == .error_set_single and b.tag() == .error_set_single) {
@@ -1203,7 +1203,7 @@ pub const Type = extern union {
                     return writer.writeAll(std.mem.sliceTo(error_set.owner_decl.name, 0));
                 },
                 .error_set_inferred => {
-                    const func = ty.castTag(.error_set_inferred).?.data.func;
+                    const func = ty.castTag(.error_set_inferred).?.data;
                     return writer.print("(inferred error set of {s})", .{func.owner_decl.name});
                 },
                 .error_set_merged => {
@@ -2869,7 +2869,7 @@ pub const Type = extern union {
     pub fn isAnyError(ty: Type) bool {
         return switch (ty.tag()) {
             .anyerror => true,
-            .error_set_inferred => ty.castTag(.error_set_inferred).?.data.is_anyerror,
+            .error_set_inferred => ty.castTag(.error_set_inferred).?.data.inferred_error_set.is_anyerror,
             else => false,
         };
     }
@@ -4156,50 +4156,7 @@ pub const Type = extern union {
             pub const base_tag = Tag.error_set_inferred;
 
             base: Payload = Payload{ .tag = base_tag },
-            data: Data,
-
-            pub const Data = struct {
-                func: *Module.Fn,
-                /// Direct additions to the inferred error set via `return error.Foo;`.
-                map: std.StringHashMapUnmanaged(void),
-                /// Other functions with inferred error sets which this error set includes.
-                functions: std.AutoHashMapUnmanaged(*Module.Fn, void),
-                is_anyerror: bool,
-
-                pub fn addErrorSet(self: *Data, gpa: Allocator, err_set_ty: Type) !void {
-                    switch (err_set_ty.tag()) {
-                        .error_set => {
-                            const names = err_set_ty.castTag(.error_set).?.data.names.keys();
-                            for (names) |name| {
-                                try self.map.put(gpa, name, {});
-                            }
-                        },
-                        .error_set_single => {
-                            const name = err_set_ty.castTag(.error_set_single).?.data;
-                            try self.map.put(gpa, name, {});
-                        },
-                        .error_set_inferred => {
-                            const func = err_set_ty.castTag(.error_set_inferred).?.data.func;
-                            try self.functions.put(gpa, func, {});
-                            var it = func.owner_decl.ty.fnReturnType().errorUnionSet()
-                                .castTag(.error_set_inferred).?.data.map.iterator();
-                            while (it.next()) |entry| {
-                                try self.map.put(gpa, entry.key_ptr.*, {});
-                            }
-                        },
-                        .error_set_merged => {
-                            const names = err_set_ty.castTag(.error_set_merged).?.data.keys();
-                            for (names) |name| {
-                                try self.map.put(gpa, name, {});
-                            }
-                        },
-                        .anyerror => {
-                            self.is_anyerror = true;
-                        },
-                        else => unreachable,
-                    }
-                }
-            };
+            data: *Module.Fn,
         };
 
         pub const Pointer = struct {