Commit b9c2837c1c

Veikka Tuominen <git@vexu.eu>
2023-02-02 13:16:15
Sema: validate inferred error set payload type
This was missed in b0a55e1b3be3a274546f9c18016e9609d546bdb0
1 parent a5d25fa
src/Sema.zig
@@ -7669,17 +7669,21 @@ fn zirErrorUnionType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
             error_set.fmt(sema.mod),
         });
     }
-    if (payload.zigTypeTag() == .Opaque) {
-        return sema.fail(block, rhs_src, "error union with payload of opaque type '{}' not allowed", .{
-            payload.fmt(sema.mod),
+    try sema.validateErrorUnionPayloadType(block, payload, rhs_src);
+    const err_union_ty = try Type.errorUnion(sema.arena, error_set, payload, sema.mod);
+    return sema.addType(err_union_ty);
+}
+
+fn validateErrorUnionPayloadType(sema: *Sema, block: *Block, payload_ty: Type, payload_src: LazySrcLoc) !void {
+    if (payload_ty.zigTypeTag() == .Opaque) {
+        return sema.fail(block, payload_src, "error union with payload of opaque type '{}' not allowed", .{
+            payload_ty.fmt(sema.mod),
         });
-    } else if (payload.zigTypeTag() == .ErrorSet) {
-        return sema.fail(block, rhs_src, "error union with payload of error set type '{}' not allowed", .{
-            payload.fmt(sema.mod),
+    } else if (payload_ty.zigTypeTag() == .ErrorSet) {
+        return sema.fail(block, payload_src, "error union with payload of error set type '{}' not allowed", .{
+            payload_ty.fmt(sema.mod),
         });
     }
-    const err_union_ty = try Type.errorUnion(sema.arena, error_set, payload, sema.mod);
-    return sema.addType(err_union_ty);
 }
 
 fn zirErrorValue(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -8639,6 +8643,7 @@ fn funcCommon(
         const return_type = if (!inferred_error_set or ret_poison)
             bare_return_type
         else blk: {
+            try sema.validateErrorUnionPayloadType(block, bare_return_type, ret_ty_src);
             const node = try sema.gpa.create(Module.Fn.InferredErrorSetListNode);
             node.data = .{ .func = new_func };
             maybe_inferred_error_set_node = node;
@@ -8650,15 +8655,15 @@ fn funcCommon(
             });
         };
 
-        if (!bare_return_type.isValidReturnType()) {
-            const opaque_str = if (bare_return_type.zigTypeTag() == .Opaque) "opaque " else "";
+        if (!return_type.isValidReturnType()) {
+            const opaque_str = if (return_type.zigTypeTag() == .Opaque) "opaque " else "";
             const msg = msg: {
                 const msg = try sema.errMsg(block, ret_ty_src, "{s}return type '{}' not allowed", .{
-                    opaque_str, bare_return_type.fmt(sema.mod),
+                    opaque_str, return_type.fmt(sema.mod),
                 });
                 errdefer msg.destroy(sema.gpa);
 
-                try sema.addDeclaredHereNote(msg, bare_return_type);
+                try sema.addDeclaredHereNote(msg, return_type);
                 break :msg msg;
             };
             return sema.failWithOwnedErrorMsg(msg);
test/cases/compile_errors/function_returning_opaque_type.zig
@@ -1,11 +1,11 @@
 const FooType = opaque {};
-export fn bar() !FooType {
+export fn bar() FooType {
     return error.InvalidValue;
 }
-export fn bav() !@TypeOf(null) {
+export fn bav() @TypeOf(null) {
     return error.InvalidValue;
 }
-export fn baz() !@TypeOf(undefined) {
+export fn baz() @TypeOf(undefined) {
     return error.InvalidValue;
 }
 
@@ -13,7 +13,7 @@ export fn baz() !@TypeOf(undefined) {
 // backend=stage2
 // target=native
 //
-// :2:18: error: opaque return type 'tmp.FooType' not allowed
+// :2:17: error: opaque return type 'tmp.FooType' not allowed
 // :1:17: note: opaque declared here
-// :5:18: error: return type '@TypeOf(null)' not allowed
-// :8:18: error: return type '@TypeOf(undefined)' not allowed
+// :5:17: error: return type '@TypeOf(null)' not allowed
+// :8:17: error: return type '@TypeOf(undefined)' not allowed
test/cases/compile_errors/invalid_error_union_payload_type.zig
@@ -4,6 +4,12 @@ comptime {
 comptime {
     _ = anyerror!anyerror;
 }
+fn someFunction() !anyerror {
+    return error.C;
+}
+comptime {
+    _ = someFunction;
+}
 
 // error
 // backend=stage2
@@ -11,3 +17,4 @@ comptime {
 //
 // :2:18: error: error union with payload of opaque type 'anyopaque' not allowed
 // :5:18: error: error union with payload of error set type 'anyerror' not allowed
+// :7:20: error: error union with payload of error set type 'anyerror' not allowed