Commit 6534f2ef4f

Andrew Kelley <andrew@ziglang.org>
2021-10-17 23:50:47
stage2: implement error wrapping
* Sema: fix returned operands not coercing to the function return type in some cases.
  - When returning an error or an error union from a function with an inferred error set, it will now populate the inferred error set.
  - Implement error set coercion for the common case of inferred error set to inferred error set, without forcing a full resolution.
* LLVM backend: update instruction lowering that handles error unions to respect `isByRef`.
  - Also implement `wrap_err_union_err`.
1 parent 53b87fa
Changed files (5)
src/codegen/llvm.zig
@@ -2301,8 +2301,7 @@ pub const FuncGen = struct {
         op: llvm.IntPredicate,
         operand_is_ptr: bool,
     ) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+        if (self.liveness.isUnused(inst)) return null;
 
         const un_op = self.air.instructions.items(.data)[inst].un_op;
         const operand = try self.resolveInst(un_op);
@@ -2363,22 +2362,16 @@ pub const FuncGen = struct {
         inst: Air.Inst.Index,
         operand_is_ptr: bool,
     ) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+        if (self.liveness.isUnused(inst)) return null;
 
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         const operand = try self.resolveInst(ty_op.operand);
         const err_union_ty = self.air.typeOf(ty_op.operand);
         const payload_ty = err_union_ty.errorUnionPayload();
-
-        if (!payload_ty.hasCodeGenBits()) {
-            return null;
-        }
-
-        if (operand_is_ptr) {
+        if (!payload_ty.hasCodeGenBits()) return null;
+        if (operand_is_ptr or isByRef(payload_ty)) {
             return self.builder.buildStructGEP(operand, 1, "");
         }
-
         return self.builder.buildExtractValue(operand, 1, "");
     }
 
@@ -2400,7 +2393,7 @@ pub const FuncGen = struct {
             return self.builder.buildLoad(operand, "");
         }
 
-        if (operand_is_ptr) {
+        if (operand_is_ptr or isByRef(payload_ty)) {
             const err_field_ptr = self.builder.buildStructGEP(operand, 0, "");
             return self.builder.buildLoad(err_field_ptr, "");
         }
@@ -2469,10 +2462,35 @@ pub const FuncGen = struct {
     }
 
     fn airWrapErrUnionErr(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const err_un_ty = self.air.typeOfIndex(inst);
+        const payload_ty = err_un_ty.errorUnionPayload();
+        const operand = try self.resolveInst(ty_op.operand);
+        if (!payload_ty.hasCodeGenBits()) {
+            return operand;
+        }
+        const err_un_llvm_ty = try self.dg.llvmType(err_un_ty);
+        if (isByRef(err_un_ty)) {
+            const result_ptr = self.buildAlloca(err_un_llvm_ty);
+            const err_ptr = self.builder.buildStructGEP(result_ptr, 0, "");
+            _ = self.builder.buildStore(operand, err_ptr);
+            const payload_ptr = self.builder.buildStructGEP(result_ptr, 1, "");
+            var ptr_ty_payload: Type.Payload.ElemType = .{
+                .base = .{ .tag = .single_mut_pointer },
+                .data = payload_ty,
+            };
+            const payload_ptr_ty = Type.initPayload(&ptr_ty_payload.base);
+            // TODO store undef to payload_ptr
+            _ = payload_ptr;
+            _ = payload_ptr_ty;
+            return result_ptr;
+        }
 
-        return self.todo("implement llvm codegen for 'airWrapErrUnionErr'", .{});
+        const partial = self.builder.buildInsertValue(err_un_llvm_ty.getUndef(), operand, 0, "");
+        // TODO set payload bytes to undef
+        return partial;
     }
 
     fn airMin(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
src/Module.zig
@@ -782,6 +782,10 @@ pub const ErrorSet = struct {
     /// The length is given by `names_len`.
     names_ptr: [*]const []const u8,
 
+    pub fn names(self: ErrorSet) []const []const u8 {
+        return self.names_ptr[0..self.names_len];
+    }
+
     pub fn srcLoc(self: ErrorSet) SrcLoc {
         return .{
             .file_scope = self.owner_decl.getFileScope(),
src/Sema.zig
@@ -4845,6 +4845,8 @@ fn funcCommon(
             const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{
                 .func = new_func,
                 .map = .{},
+                .functions = .{},
+                .is_anyerror = false,
             });
             break :blk try Type.Tag.error_union.create(sema.arena, .{
                 .error_set = error_set_ty,
@@ -8466,19 +8468,13 @@ fn zirRetErrValue(
     const err_name = inst_data.get(sema.code);
     const src = inst_data.src();
 
-    // Add the error tag to the inferred error set of the in-scope function.
-    if (sema.fn_ret_ty.zigTypeTag() == .ErrorUnion) {
-        if (sema.fn_ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
-            _ = try payload.data.map.getOrPut(sema.gpa, err_name);
-        }
-    }
     // Return the error code from the function.
     const kv = try sema.mod.getErrorValue(err_name);
     const result_inst = try sema.addConstant(
         try Type.Tag.error_set_single.create(sema.arena, kv.key),
         try Value.Tag.@"error".create(sema.arena, .{ .name = kv.key }),
     );
-    return sema.analyzeRet(block, result_inst, src, true);
+    return sema.analyzeRet(block, result_inst, src);
 }
 
 fn zirRetCoerce(
@@ -8493,7 +8489,7 @@ fn zirRetCoerce(
     const operand = sema.resolveInst(inst_data.operand);
     const src = inst_data.src();
 
-    return sema.analyzeRet(block, operand, src, true);
+    return sema.analyzeRet(block, operand, src);
 }
 
 fn zirRetNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index {
@@ -8504,11 +8500,7 @@ fn zirRetNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir
     const operand = sema.resolveInst(inst_data.operand);
     const src = inst_data.src();
 
-    // TODO: we pass false here for the `need_coercion` boolean, but I'm pretty sure we need
-    // to remove this parameter entirely. Observe the problem by looking at the incorrect compile
-    // error that occurs when a behavior test case being executed at comptime fails, e.g.
-    // `test { comptime foo(); } fn foo() { try expect(false); }`
-    return sema.analyzeRet(block, operand, src, false);
+    return sema.analyzeRet(block, operand, src);
 }
 
 fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index {
@@ -8521,7 +8513,7 @@ fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir
 
     if (block.is_comptime or block.inlining != null) {
         const operand = try sema.analyzeLoad(block, src, ret_ptr, src);
-        return sema.analyzeRet(block, operand, src, false);
+        return sema.analyzeRet(block, operand, src);
     }
     try sema.requireRuntimeBlock(block, src);
     _ = try block.addUnOp(.ret_load, ret_ptr);
@@ -8533,12 +8525,25 @@ fn analyzeRet(
     block: *Block,
     uncasted_operand: Air.Inst.Ref,
     src: LazySrcLoc,
-    need_coercion: bool,
 ) CompileError!Zir.Inst.Index {
-    const operand = if (!need_coercion)
-        uncasted_operand
-    else
-        try sema.coerce(block, sema.fn_ret_ty, uncasted_operand, src);
+    // Special case for returning an error to an inferred error set; we need to
+    // add the error tag to the inferred error set of the in-scope function, so
+    // that the coercion below works correctly.
+    if (sema.fn_ret_ty.zigTypeTag() == .ErrorUnion) {
+        if (sema.fn_ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
+            const op_ty = sema.typeOf(uncasted_operand);
+            switch (op_ty.zigTypeTag()) {
+                .ErrorSet => {
+                    try payload.data.addErrorSet(sema.gpa, op_ty);
+                },
+                .ErrorUnion => {
+                    try payload.data.addErrorSet(sema.gpa, op_ty.errorUnionSet());
+                },
+                else => {},
+            }
+        }
+    }
+    const operand = try sema.coerce(block, sema.fn_ret_ty, uncasted_operand, src);
 
     if (block.inlining) |inlining| {
         if (block.is_comptime) {
@@ -11605,14 +11610,30 @@ fn coerce(
             // T to E!T or E to E!T
             return sema.wrapErrorUnion(block, dest_ty, inst, inst_src);
         },
-        .ErrorSet => {
-            // Coercion to `anyerror`.
-            // TODO If the dest type tag is not `anyerror` it still could
-            // resolve to anyerror. `dest_ty` needs to have inferred error set resolution
-            // happen before this check.
-            if (dest_ty.tag() == .anyerror and inst_ty.zigTypeTag() == .ErrorSet) {
-                return sema.coerceErrSetToAnyError(block, inst, inst_src);
-            }
+        .ErrorSet => switch (inst_ty.zigTypeTag()) {
+            .ErrorSet => {
+                // Coercion to `anyerror`. Note that this check can return false negatives
+                // in case the error sets did not get resolved.
+                if (dest_ty.isAnyError()) {
+                    return sema.coerceCompatibleErrorSets(block, inst, inst_src);
+                }
+                // If both are inferred error sets of functions, and
+                // the dest includes the source function, the coercion is OK.
+                // This check is important because it works without forcing a full resolution
+                // of inferred error sets.
+                if (inst_ty.castTag(.error_set_inferred)) |src_payload| {
+                    if (dest_ty.castTag(.error_set_inferred)) |dst_payload| {
+                        const src_func = src_payload.data.func;
+                        const dst_func = dst_payload.data.func;
+
+                        if (src_func == dst_func or dst_payload.data.functions.contains(src_func)) {
+                            return sema.coerceCompatibleErrorSets(block, inst, inst_src);
+                        }
+                    }
+                }
+                // TODO full error set resolution and compare sets by names.
+            },
+            else => {},
         },
         .Union => switch (inst_ty.zigTypeTag()) {
             .Enum, .EnumLiteral => return sema.coerceEnumToUnion(block, dest_ty, dest_ty_src, inst, inst_src),
@@ -12245,7 +12266,7 @@ fn coerceVectorToArray(
     return block.addTyOp(.bitcast, array_ty, vector);
 }
 
-fn coerceErrSetToAnyError(
+fn coerceCompatibleErrorSets(
     sema: *Sema,
     block: *Block,
     err_set: Air.Inst.Ref,
src/type.zig
@@ -2619,6 +2619,17 @@ pub const Type = extern union {
         };
     }
 
+    /// Returns true if it is an error set that includes anyerror, false otherwise.
+    /// Note that the result may be a false negative if the type did not get error set
+    /// resolution prior to this call.
+    pub fn isAnyError(ty: Type) bool {
+        return switch (ty.tag()) {
+            .anyerror => true,
+            .error_set_inferred => ty.castTag(.error_set_inferred).?.data.is_anyerror,
+            else => false,
+        };
+    }
+
     /// Asserts the type is an array or vector.
     pub fn arrayLen(ty: Type) u64 {
         return switch (ty.tag()) {
@@ -3871,10 +3882,39 @@ pub const Type = extern union {
             pub const base_tag = Tag.error_set_inferred;
 
             base: Payload = Payload{ .tag = base_tag },
-            data: struct {
+            data: Data,
+
+            pub const Data = struct {
                 func: *Module.Fn,
+                /// Direct additions to the inferred error set via `return error.Foo;`.
                 map: std.StringHashMapUnmanaged(void),
-            },
+                /// Other functions with inferred error sets which this error set includes.
+                functions: std.AutoHashMapUnmanaged(*Module.Fn, void),
+                is_anyerror: bool,
+
+                pub fn addErrorSet(self: *Data, gpa: *Allocator, err_set_ty: Type) !void {
+                    switch (err_set_ty.tag()) {
+                        .error_set => {
+                            const names = err_set_ty.castTag(.error_set).?.data.names();
+                            for (names) |name| {
+                                try self.map.put(gpa, name, {});
+                            }
+                        },
+                        .error_set_single => {
+                            const name = err_set_ty.castTag(.error_set_single).?.data;
+                            try self.map.put(gpa, name, {});
+                        },
+                        .error_set_inferred => {
+                            const func = err_set_ty.castTag(.error_set_inferred).?.data.func;
+                            try self.functions.put(gpa, func, {});
+                        },
+                        .anyerror => {
+                            self.is_anyerror = true;
+                        },
+                        else => unreachable,
+                    }
+                }
+            };
         };
 
         pub const Pointer = struct {
test/behavior/error.zig
@@ -31,3 +31,21 @@ test "empty error union" {
     const x = error{} || error{};
     _ = x;
 }
+
+pub fn foo() anyerror!i32 {
+    const x = try bar();
+    return x + 1;
+}
+
+pub fn bar() anyerror!i32 {
+    return 13;
+}
+
+pub fn baz() anyerror!i32 {
+    const y = foo() catch 1234;
+    return y + 1;
+}
+
+test "error wrapping" {
+    try expect((baz() catch unreachable) == 15);
+}