Commit 927f6ec8ca

Andrew Kelley <andrew@ziglang.org>
2023-07-13 09:40:16
frontend: fix inferred error sets of comptime/inline calls
Previously, they shared function index with the owner decl, but that would clobber the data stored for inferred error sets of runtime calls. Now there is an adhoc_inferred_error_set_type which models the problem much more correctly.
1 parent 82db06f
src/Air.zig
@@ -946,6 +946,7 @@ pub const Inst = struct {
         slice_const_u8_sentinel_0_type = @intFromEnum(InternPool.Index.slice_const_u8_sentinel_0_type),
         optional_noreturn_type = @intFromEnum(InternPool.Index.optional_noreturn_type),
         anyerror_void_error_union_type = @intFromEnum(InternPool.Index.anyerror_void_error_union_type),
+        adhoc_inferred_error_set_type = @intFromEnum(InternPool.Index.adhoc_inferred_error_set_type),
         generic_poison_type = @intFromEnum(InternPool.Index.generic_poison_type),
         empty_struct_type = @intFromEnum(InternPool.Index.empty_struct_type),
         undef = @intFromEnum(InternPool.Index.undef),
src/InternPool.zig
@@ -1450,6 +1450,8 @@ pub const Index = enum(u32) {
     slice_const_u8_sentinel_0_type,
     optional_noreturn_type,
     anyerror_void_error_union_type,
+    /// Used for the inferred error set of inline/comptime function calls.
+    adhoc_inferred_error_set_type,
     generic_poison_type,
     /// `@TypeOf(.{})`
     empty_struct_type,
@@ -1886,6 +1888,8 @@ pub const static_keys = [_]Key{
         .payload_type = .void_type,
     } },
 
+    // adhoc_inferred_error_set_type
+    .{ .simple_type = .adhoc_inferred_error_set },
     // generic_poison_type
     .{ .simple_type = .generic_poison },
 
@@ -2496,6 +2500,7 @@ pub const SimpleType = enum(u32) {
     extern_options,
     type_info,
 
+    adhoc_inferred_error_set,
     generic_poison,
 };
 
@@ -5812,14 +5817,17 @@ pub fn isOptionalType(ip: *const InternPool, ty: Index) bool {
 
 /// includes .inferred_error_set_type
 pub fn isErrorSetType(ip: *const InternPool, ty: Index) bool {
-    return ty == .anyerror_type or switch (ip.indexToKey(ty)) {
-        .error_set_type, .inferred_error_set_type => true,
-        else => false,
+    return switch (ty) {
+        .anyerror_type, .adhoc_inferred_error_set_type => true,
+        else => switch (ip.indexToKey(ty)) {
+            .error_set_type, .inferred_error_set_type => true,
+            else => false,
+        },
     };
 }
 
 pub fn isInferredErrorSetType(ip: *const InternPool, ty: Index) bool {
-    return ip.indexToKey(ty) == .inferred_error_set_type;
+    return ty == .adhoc_inferred_error_set_type or ip.indexToKey(ty) == .inferred_error_set_type;
 }
 
 pub fn isErrorUnionType(ip: *const InternPool, ty: Index) bool {
@@ -6412,6 +6420,7 @@ pub fn typeOf(ip: *const InternPool, index: Index) Index {
         .slice_const_u8_sentinel_0_type,
         .optional_noreturn_type,
         .anyerror_void_error_union_type,
+        .adhoc_inferred_error_set_type,
         .generic_poison_type,
         .empty_struct_type,
         => .type_type,
@@ -6688,7 +6697,7 @@ pub fn zigTypeTagOrPoison(ip: *const InternPool, index: Index) error{GenericPois
         .bool_type => .Bool,
         .void_type => .Void,
         .type_type => .Type,
-        .anyerror_type => .ErrorSet,
+        .anyerror_type, .adhoc_inferred_error_set_type => .ErrorSet,
         .comptime_int_type => .ComptimeInt,
         .comptime_float_type => .ComptimeFloat,
         .noreturn_type => .NoReturn,
src/Sema.zig
@@ -134,16 +134,19 @@ pub const default_reference_trace_len = 2;
 
 pub const InferredErrorSet = struct {
     /// The function body from which this error set originates.
+    /// This is `none` in the case of a comptime/inline function call, corresponding to
+    /// `InternPool.Index.adhoc_inferred_error_set_type`.
+    /// The function's resolved error set is not set until analysis of the
+    /// function body completes.
     func: InternPool.Index,
-
     /// All currently known errors that this error set contains. This includes
     /// direct additions via `return error.Foo;`, and possibly also errors that
-    /// are returned from any dependent functions. When the inferred error set is
-    /// fully resolved, this map contains all the errors that the function might return.
+    /// are returned from any dependent functions.
     errors: NameMap = .{},
-
     /// Other inferred error sets which this inferred error set should include.
     inferred_error_sets: std.AutoArrayHashMapUnmanaged(InternPool.Index, void) = .{},
+    /// The regular error set created by resolving this inferred error set.
+    resolved: InternPool.Index = .none,
 
     pub const NameMap = std.AutoArrayHashMapUnmanaged(InternPool.NullTerminatedString, void);
 
@@ -155,7 +158,7 @@ pub const InferredErrorSet = struct {
     ) !void {
         switch (err_set_ty.toIntern()) {
             .anyerror_type => {
-                ip.funcIesResolved(self.func).* = .anyerror_type;
+                self.resolved = .anyerror_type;
             },
             else => switch (ip.indexToKey(err_set_ty.toIntern())) {
                 .error_set_type => |error_set_type| {
@@ -7060,7 +7063,6 @@ fn analyzeCall(
                 .error_set_type = error_set_ty,
                 .payload_type = bare_return_type.toIntern(),
             } })).toType();
-            ip.funcIesResolved(module_fn_index).* = .none;
         }
 
         // This `res2` is here instead of directly breaking from `res` due to a stage1
@@ -7123,7 +7125,9 @@ fn analyzeCall(
                 break :result try sema.analyzeBlockBody(block, call_src, &child_block, merges);
             };
 
-            if (!is_comptime_call and !block.is_typeof and sema.typeOf(result).zigTypeTag(mod) != .NoReturn) {
+            if (!is_comptime_call and !block.is_typeof and
+                sema.typeOf(result).zigTypeTag(mod) != .NoReturn)
+            {
                 try sema.emitDbgInline(
                     block,
                     module_fn_index,
@@ -7137,13 +7141,23 @@ fn analyzeCall(
                 const result_val = try sema.resolveConstMaybeUndefVal(block, .unneeded, result, "");
                 const result_interned = try result_val.intern(sema.fn_ret_ty, mod);
 
+                // Transform ad-hoc inferred error set types into concrete error sets.
+                const result_transformed = try sema.resolveAdHocInferredErrorSet(block, call_src, result_interned);
+
                 // TODO: check whether any external comptime memory was mutated by the
                 // comptime function call. If so, then do not memoize the call here.
                 _ = try mod.intern(.{ .memoized_call = .{
                     .func = module_fn_index,
                     .arg_values = memoized_arg_values,
-                    .result = result_interned,
+                    .result = result_transformed,
                 } });
+
+                break :res2 Air.internedToRef(result_transformed);
+            }
+
+            if (sema.fn_ret_ty_ies) |ies| {
+                _ = ies;
+                @panic("TODO: resolve ad-hoc inferred error set");
             }
 
             break :res2 result;
@@ -18237,19 +18251,30 @@ fn zirRestoreErrRetIndex(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index)
 
 fn addToInferredErrorSet(sema: *Sema, uncasted_operand: Air.Inst.Ref) !void {
     const mod = sema.mod;
-    const gpa = sema.gpa;
     const ip = &mod.intern_pool;
     assert(sema.fn_ret_ty.zigTypeTag(mod) == .ErrorUnion);
+    const err_set_ty = sema.fn_ret_ty.errorUnionSet(mod).toIntern();
+    switch (err_set_ty) {
+        .adhoc_inferred_error_set_type => {
+            const ies = sema.fn_ret_ty_ies.?;
+            assert(ies.func == .none);
+            try addToInferredErrorSetPtr(mod, ies, sema.typeOf(uncasted_operand));
+        },
+        else => if (ip.isInferredErrorSetType(err_set_ty)) {
+            const ies = sema.fn_ret_ty_ies.?;
+            assert(ies.func == sema.func_index);
+            try addToInferredErrorSetPtr(mod, ies, sema.typeOf(uncasted_operand));
+        },
+    }
+}
 
-    if (ip.isInferredErrorSetType(sema.fn_ret_ty.errorUnionSet(mod).toIntern())) {
-        const ies = sema.fn_ret_ty_ies.?;
-        assert(ies.func == sema.func_index);
-        const op_ty = sema.typeOf(uncasted_operand);
-        switch (op_ty.zigTypeTag(mod)) {
-            .ErrorSet => try ies.addErrorSet(op_ty, ip, gpa),
-            .ErrorUnion => try ies.addErrorSet(op_ty.errorUnionSet(mod), ip, gpa),
-            else => {},
-        }
+fn addToInferredErrorSetPtr(mod: *Module, ies: *InferredErrorSet, op_ty: Type) !void {
+    const gpa = mod.gpa;
+    const ip = &mod.intern_pool;
+    switch (op_ty.zigTypeTag(mod)) {
+        .ErrorSet => try ies.addErrorSet(op_ty, ip, gpa),
+        .ErrorUnion => try ies.addErrorSet(op_ty.errorUnionSet(mod), ip, gpa),
+        else => {},
     }
 }
 
@@ -27936,6 +27961,14 @@ fn coerceInMemoryAllowedErrorSets(
         return .ok;
     }
 
+    if (dest_ty.toIntern() == .adhoc_inferred_error_set_type) {
+        // We are trying to coerce an error set to the current function's
+        // inferred error set.
+        const dst_ies = sema.fn_ret_ty_ies.?;
+        try dst_ies.addErrorSet(src_ty, ip, gpa);
+        return .ok;
+    }
+
     if (ip.isInferredErrorSetType(dest_ty.toIntern())) {
         const dst_ies_func_index = ip.iesFuncIndex(dest_ty.toIntern());
         if (sema.fn_ret_ty_ies) |dst_ies| {
@@ -27946,7 +27979,6 @@ fn coerceInMemoryAllowedErrorSets(
                 return .ok;
             }
         }
-
         switch (try sema.resolveInferredErrorSet(block, dest_src, dest_ty.toIntern())) {
             // isAnyError might have changed from a false negative to a true
             // positive after resolution.
@@ -30551,21 +30583,25 @@ fn analyzeIsNonErrComptimeOnly(
                     else => |i| if (ip.indexToKey(i).error_set_type.names.len != 0) break :blk,
                 }
                 if (maybe_operand_val == null) {
-                    if (sema.fn_ret_ty_ies) |ies| if (ies.func == func_index) {
-                        // Try to avoid resolving inferred error set if possible.
-                        for (ies.inferred_error_sets.keys()) |other_ies_index| {
-                            if (set_ty == other_ies_index) continue;
-                            const other_resolved =
-                                try sema.resolveInferredErrorSet(block, src, other_ies_index);
-                            if (other_resolved == .anyerror_type) {
-                                ip.funcIesResolved(func_index).* = .anyerror_type;
-                                break :blk;
+                    if (sema.fn_ret_ty_ies) |ies| {
+                        if (set_ty == .adhoc_inferred_error_set_type or
+                            ies.func == func_index)
+                        {
+                            // Try to avoid resolving inferred error set if possible.
+                            for (ies.inferred_error_sets.keys()) |other_ies_index| {
+                                if (set_ty == other_ies_index) continue;
+                                const other_resolved =
+                                    try sema.resolveInferredErrorSet(block, src, other_ies_index);
+                                if (other_resolved == .anyerror_type) {
+                                    ip.funcIesResolved(func_index).* = .anyerror_type;
+                                    break :blk;
+                                }
+                                if (ip.indexToKey(other_resolved).error_set_type.names.len != 0)
+                                    break :blk;
                             }
-                            if (ip.indexToKey(other_resolved).error_set_type.names.len != 0)
-                                break :blk;
+                            return .bool_true;
                         }
-                        return .bool_true;
-                    };
+                    }
                     const resolved_ty = try sema.resolveInferredErrorSet(block, src, set_ty);
                     if (resolved_ty == .anyerror_type)
                         break :blk;
@@ -31520,18 +31556,30 @@ fn wrapErrorUnionSet(
     const inst_ty = sema.typeOf(inst);
     const dest_err_set_ty = dest_ty.errorUnionSet(mod);
     if (try sema.resolveMaybeUndefVal(inst)) |val| {
+        const expected_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
         switch (dest_err_set_ty.toIntern()) {
             .anyerror_type => {},
+            .adhoc_inferred_error_set_type => ok: {
+                const ies = sema.fn_ret_ty_ies.?;
+                switch (ies.resolved) {
+                    .anyerror_type => break :ok,
+                    .none => if (.ok == try sema.coerceInMemoryAllowedErrorSets(block, dest_err_set_ty, inst_ty, inst_src, inst_src)) {
+                        break :ok;
+                    },
+                    else => |i| if (ip.indexToKey(i).error_set_type.nameIndex(ip, expected_name) != null) {
+                        break :ok;
+                    },
+                }
+                return sema.failWithErrorSetCodeMissing(block, inst_src, dest_err_set_ty, inst_ty);
+            },
             else => switch (ip.indexToKey(dest_err_set_ty.toIntern())) {
                 .error_set_type => |error_set_type| ok: {
-                    const expected_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
                     if (error_set_type.nameIndex(ip, expected_name) != null) break :ok;
                     return sema.failWithErrorSetCodeMissing(block, inst_src, dest_err_set_ty, inst_ty);
                 },
                 .inferred_error_set_type => |func_index| ok: {
                     // We carefully do this in an order that avoids unnecessarily
                     // resolving the destination error set type.
-                    const expected_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
                     switch (ip.funcIesResolved(func_index).*) {
                         .anyerror_type => break :ok,
                         .none => if (.ok == try sema.coerceInMemoryAllowedErrorSets(block, dest_err_set_ty, inst_ty, inst_src, inst_src)) {
@@ -31549,9 +31597,7 @@ fn wrapErrorUnionSet(
         }
         return sema.addConstant((try mod.intern(.{ .error_union = .{
             .ty = dest_ty.toIntern(),
-            .val = .{
-                .err_name = mod.intern_pool.indexToKey(try val.intern(dest_err_set_ty, mod)).err.name,
-            },
+            .val = .{ .err_name = expected_name },
         } })).toValue());
     }
 
@@ -33033,7 +33079,11 @@ pub fn resolveFnTypes(sema: *Sema, block: *Block, src: LazySrcLoc, fn_ty: Type)
     const ip = &mod.intern_pool;
     const fn_ty_info = mod.typeToFunc(fn_ty).?;
 
-    if (sema.fn_ret_ty_ies) |ies| try sema.resolveInferredErrorSetPtr(block, src, ies);
+    if (sema.fn_ret_ty_ies) |ies| {
+        try sema.resolveInferredErrorSetPtr(block, src, ies);
+        assert(ies.resolved != .none);
+        ip.funcIesResolved(sema.func_index).* = ies.resolved;
+    }
 
     try sema.resolveTypeFully(fn_ty_info.return_type.toType());
 
@@ -33565,6 +33615,7 @@ pub fn resolveTypeRequiresComptime(sema: *Sema, ty: Type) CompileError!bool {
                 .bool,
                 .void,
                 .anyerror,
+                .adhoc_inferred_error_set,
                 .noreturn,
                 .generic_poison,
                 .atomic_order,
@@ -33815,6 +33866,7 @@ pub fn resolveTypeFields(sema: *Sema, ty: Type) CompileError!Type {
         .void_type,
         .type_type,
         .anyerror_type,
+        .adhoc_inferred_error_set_type,
         .comptime_int_type,
         .comptime_float_type,
         .noreturn_type,
@@ -34032,8 +34084,7 @@ fn resolveInferredErrorSetPtr(
     const mod = sema.mod;
     const ip = &mod.intern_pool;
 
-    const func = mod.funcInfo(ies.func);
-    if (func.resolvedErrorSet(ip).* != .none) return;
+    if (ies.resolved != .none) return;
 
     const ies_index = ip.errorUnionSet(sema.fn_ret_ty.toIntern());
 
@@ -34041,7 +34092,7 @@ fn resolveInferredErrorSetPtr(
         if (ies_index == other_ies_index) continue;
         switch (try sema.resolveInferredErrorSet(block, src, other_ies_index)) {
             .anyerror_type => {
-                func.resolvedErrorSet(ip).* = .anyerror_type;
+                ies.resolved = .anyerror_type;
                 return;
             },
             else => |error_set_ty_index| {
@@ -34054,7 +34105,33 @@ fn resolveInferredErrorSetPtr(
     }
 
     const resolved_error_set_ty = try mod.errorSetFromUnsortedNames(ies.errors.keys());
-    func.resolvedErrorSet(ip).* = resolved_error_set_ty.toIntern();
+    ies.resolved = resolved_error_set_ty.toIntern();
+}
+
+fn resolveAdHocInferredErrorSet(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    value: InternPool.Index,
+) CompileError!InternPool.Index {
+    const ies = sema.fn_ret_ty_ies orelse return value;
+    const mod = sema.mod;
+    const gpa = sema.gpa;
+    const ip = &mod.intern_pool;
+    const ty = ip.typeOf(value);
+    const error_union_info = switch (ip.indexToKey(ty)) {
+        .error_union_type => |x| x,
+        else => return value,
+    };
+    if (error_union_info.error_set_type != .adhoc_inferred_error_set_type)
+        return value;
+
+    try sema.resolveInferredErrorSetPtr(block, src, ies);
+    const new_ty = try ip.get(gpa, .{ .error_union_type = .{
+        .error_set_type = ies.resolved,
+        .payload_type = error_union_info.payload_type,
+    } });
+    return ip.getCoerced(gpa, value, new_ty);
 }
 
 fn resolveInferredErrorSetTy(
@@ -35037,6 +35114,7 @@ pub fn typeHasOnePossibleValue(sema: *Sema, ty: Type) CompileError!?Value {
         .bool_type,
         .type_type,
         .anyerror_type,
+        .adhoc_inferred_error_set_type,
         .comptime_int_type,
         .comptime_float_type,
         .enum_literal_type,
@@ -35692,6 +35770,7 @@ pub fn typeRequiresComptime(sema: *Sema, ty: Type) CompileError!bool {
                 .prefetch_options,
                 .export_options,
                 .extern_options,
+                .adhoc_inferred_error_set,
                 => false,
 
                 .type,
src/type.zig
@@ -292,6 +292,7 @@ pub const Type = struct {
                 .comptime_int,
                 .comptime_float,
                 .noreturn,
+                .adhoc_inferred_error_set,
                 => return writer.writeAll(@tagName(s)),
 
                 .null,
@@ -533,6 +534,7 @@ pub const Type = struct {
                     .c_longdouble,
                     .bool,
                     .anyerror,
+                    .adhoc_inferred_error_set,
                     .anyopaque,
                     .atomic_order,
                     .atomic_rmw_op,
@@ -696,6 +698,7 @@ pub const Type = struct {
                 => true,
 
                 .anyerror,
+                .adhoc_inferred_error_set,
                 .anyopaque,
                 .atomic_order,
                 .atomic_rmw_op,
@@ -954,7 +957,9 @@ pub const Type = struct {
                     },
 
                     // TODO revisit this when we have the concept of the error tag type
-                    .anyerror => return AbiAlignmentAdvanced{ .scalar = 2 },
+                    .anyerror,
+                    .adhoc_inferred_error_set,
+                    => return AbiAlignmentAdvanced{ .scalar = 2 },
 
                     .void,
                     .type,
@@ -1418,7 +1423,9 @@ pub const Type = struct {
                     => return AbiSizeAdvanced{ .scalar = 0 },
 
                     // TODO revisit this when we have the concept of the error tag type
-                    .anyerror => return AbiSizeAdvanced{ .scalar = 2 },
+                    .anyerror,
+                    .adhoc_inferred_error_set,
+                    => return AbiSizeAdvanced{ .scalar = 2 },
 
                     .prefetch_options => unreachable, // missing call to resolveTypeFields
                     .export_options => unreachable, // missing call to resolveTypeFields
@@ -1661,7 +1668,9 @@ pub const Type = struct {
                 .void => return 0,
 
                 // TODO revisit this when we have the concept of the error tag type
-                .anyerror => return 16,
+                .anyerror,
+                .adhoc_inferred_error_set,
+                => return 16,
 
                 .anyopaque => unreachable,
                 .type => unreachable,
@@ -2503,6 +2512,7 @@ pub const Type = struct {
                     .export_options,
                     .extern_options,
                     .type_info,
+                    .adhoc_inferred_error_set,
                     => return null,
 
                     .void => return Value.void,
@@ -2697,6 +2707,7 @@ pub const Type = struct {
                     .bool,
                     .void,
                     .anyerror,
+                    .adhoc_inferred_error_set,
                     .noreturn,
                     .generic_poison,
                     .atomic_order,
src/Zir.zig
@@ -2087,6 +2087,7 @@ pub const Inst = struct {
         slice_const_u8_sentinel_0_type = @intFromEnum(InternPool.Index.slice_const_u8_sentinel_0_type),
         optional_noreturn_type = @intFromEnum(InternPool.Index.optional_noreturn_type),
         anyerror_void_error_union_type = @intFromEnum(InternPool.Index.anyerror_void_error_union_type),
+        adhoc_inferred_error_set_type = @intFromEnum(InternPool.Index.adhoc_inferred_error_set_type),
         generic_poison_type = @intFromEnum(InternPool.Index.generic_poison_type),
         empty_struct_type = @intFromEnum(InternPool.Index.empty_struct_type),
         undef = @intFromEnum(InternPool.Index.undef),