Commit c3d5428cba

Veikka Tuominen <git@vexu.eu>
2022-08-17 14:07:20
Sema: properly handle noreturn fields in unions
1 parent b0a55e1
src/Sema.zig
@@ -3772,6 +3772,7 @@ fn validateStructInit(
     }
 
     var root_msg: ?*Module.ErrorMsg = null;
+    errdefer if (root_msg) |msg| msg.destroy(sema.gpa);
 
     const struct_ptr = try sema.resolveInst(struct_ptr_zir_ref);
     if ((is_comptime or block.is_comptime) and
@@ -3947,6 +3948,7 @@ fn validateStructInit(
     }
 
     if (root_msg) |msg| {
+        root_msg = null;
         if (struct_ty.castTag(.@"struct")) |struct_obj| {
             const fqn = try struct_obj.data.getFullyQualifiedName(sema.mod);
             defer gpa.free(fqn);
@@ -4005,6 +4007,8 @@ fn zirValidateArrayInit(
     if (instrs.len != array_len and array_ty.isTuple()) {
         const struct_obj = array_ty.castTag(.tuple).?.data;
         var root_msg: ?*Module.ErrorMsg = null;
+        errdefer if (root_msg) |msg| msg.destroy(sema.gpa);
+
         for (struct_obj.values) |default_val, i| {
             if (i < instrs.len) continue;
 
@@ -4019,6 +4023,7 @@ fn zirValidateArrayInit(
         }
 
         if (root_msg) |msg| {
+            root_msg = null;
             return sema.failWithOwnedErrorMsg(msg);
         }
     }
@@ -8964,12 +8969,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         },
     };
 
-    const union_originally = blk: {
+    const maybe_union_ty = blk: {
         const zir_data = sema.code.instructions.items(.data);
         const cond_index = Zir.refToIndex(extra.data.operand).?;
         const raw_operand = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable;
-        break :blk sema.typeOf(raw_operand).zigTypeTag() == .Union;
+        break :blk sema.typeOf(raw_operand);
     };
+    const union_originally = maybe_union_ty.zigTypeTag() == .Union;
+    var seen_union_fields: []?Module.SwitchProngSrc = &.{};
+    defer gpa.free(seen_union_fields);
 
     const operand_ty = sema.typeOf(operand);
 
@@ -9004,7 +9012,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         .Union => unreachable, // handled in zirSwitchCond
         .Enum => {
             var seen_fields = try gpa.alloc(?Module.SwitchProngSrc, operand_ty.enumFieldCount());
-            defer gpa.free(seen_fields);
+            defer if (!union_originally) gpa.free(seen_fields);
+            if (union_originally) seen_union_fields = seen_fields;
             mem.set(?Module.SwitchProngSrc, seen_fields, null);
 
             // This is used for non-exhaustive enum values that do not correspond to any tags.
@@ -9637,18 +9646,28 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         const item = try sema.resolveInst(item_ref);
         // `item` is already guaranteed to be constant known.
 
-        _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
-            error.ComptimeBreak => {
-                const zir_datas = sema.code.instructions.items(.data);
-                const break_data = zir_datas[sema.comptime_break_inst].@"break";
-                try sema.addRuntimeBreak(&case_block, .{
-                    .block_inst = break_data.block_inst,
-                    .operand = break_data.operand,
-                    .inst = sema.comptime_break_inst,
-                });
-            },
-            else => |e| return e,
-        };
+        const analyze_body = if (union_originally) blk: {
+            const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable;
+            const field_ty = maybe_union_ty.unionFieldType(item_val, sema.mod);
+            break :blk field_ty.zigTypeTag() != .NoReturn;
+        } else true;
+
+        if (analyze_body) {
+            _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                error.ComptimeBreak => {
+                    const zir_datas = sema.code.instructions.items(.data);
+                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                    try sema.addRuntimeBreak(&case_block, .{
+                        .block_inst = break_data.block_inst,
+                        .operand = break_data.operand,
+                        .inst = sema.comptime_break_inst,
+                    });
+                },
+                else => |e| return e,
+            };
+        } else {
+            _ = try case_block.addNoOp(.unreach);
+        }
 
         try wip_captures.finalize();
 
@@ -9689,20 +9708,34 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (ranges_len == 0) {
             cases_len += 1;
 
+            const analyze_body = if (union_originally)
+                for (items) |item_ref| {
+                    const item = try sema.resolveInst(item_ref);
+                    const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable;
+                    const field_ty = maybe_union_ty.unionFieldType(item_val, sema.mod);
+                    if (field_ty.zigTypeTag() != .NoReturn) break true;
+                } else false
+            else
+                true;
+
             const body = sema.code.extra[extra_index..][0..body_len];
             extra_index += body_len;
-            _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
-                error.ComptimeBreak => {
-                    const zir_datas = sema.code.instructions.items(.data);
-                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
-                    try sema.addRuntimeBreak(&case_block, .{
-                        .block_inst = break_data.block_inst,
-                        .operand = break_data.operand,
-                        .inst = sema.comptime_break_inst,
-                    });
-                },
-                else => |e| return e,
-            };
+            if (analyze_body) {
+                _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                    error.ComptimeBreak => {
+                        const zir_datas = sema.code.instructions.items(.data);
+                        const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                        try sema.addRuntimeBreak(&case_block, .{
+                            .block_inst = break_data.block_inst,
+                            .operand = break_data.operand,
+                            .inst = sema.comptime_break_inst,
+                        });
+                    },
+                    else => |e| return e,
+                };
+            } else {
+                _ = try case_block.addNoOp(.unreach);
+            }
 
             try cases_extra.ensureUnusedCapacity(gpa, 2 + items.len +
                 case_block.instructions.items.len);
@@ -9824,7 +9857,17 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         case_block.instructions.shrinkRetainingCapacity(0);
         case_block.wip_capture_scope = wip_captures.scope;
 
-        if (special.body.len != 0) {
+        const analyze_body = if (union_originally)
+            for (seen_union_fields) |seen_field, index| {
+                if (seen_field != null) continue;
+                const union_obj = maybe_union_ty.cast(Type.Payload.Union).?.data;
+                const field_ty = union_obj.fields.values()[index].ty;
+                if (field_ty.zigTypeTag() != .NoReturn) break true;
+            } else false
+        else
+            true;
+
+        if (special.body.len != 0 and analyze_body) {
             _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
                 error.ComptimeBreak => {
                     const zir_datas = sema.code.instructions.items(.data);
@@ -13225,6 +13268,14 @@ fn analyzeCmpUnionTag(
     const coerced_tag = try sema.coerce(block, union_tag_ty, tag, tag_src);
     const coerced_union = try sema.coerce(block, union_tag_ty, un, un_src);
 
+    if (try sema.resolveMaybeUndefVal(block, tag_src, coerced_tag)) |enum_val| {
+        if (enum_val.isUndef()) return sema.addConstUndef(Type.bool);
+        const field_ty = union_ty.unionFieldType(enum_val, sema.mod);
+        if (field_ty.zigTypeTag() == .NoReturn) {
+            return Air.Inst.Ref.bool_false;
+        }
+    }
+
     return sema.cmpSelf(block, src, coerced_union, coerced_tag, op, un_src, tag_src);
 }
 
@@ -15579,6 +15630,8 @@ fn finishStructInit(
     const gpa = sema.gpa;
 
     var root_msg: ?*Module.ErrorMsg = null;
+    errdefer if (root_msg) |msg| msg.destroy(sema.gpa);
+
     if (struct_ty.isAnonStruct()) {
         const struct_obj = struct_ty.castTag(.anon_struct).?.data;
         for (struct_obj.values) |default_val, i| {
@@ -15634,6 +15687,7 @@ fn finishStructInit(
     }
 
     if (root_msg) |msg| {
+        root_msg = null;
         if (struct_ty.castTag(.@"struct")) |struct_obj| {
             const fqn = try struct_obj.data.getFullyQualifiedName(sema.mod);
             defer gpa.free(fqn);
@@ -21682,6 +21736,18 @@ fn unionFieldPtr(
         .@"addrspace" = union_ptr_ty.ptrAddressSpace(),
     });
 
+    if (initializing and field.ty.zigTypeTag() == .NoReturn) {
+        const msg = msg: {
+            const msg = try sema.errMsg(block, src, "cannot initialize 'noreturn' field of union", .{});
+            errdefer msg.destroy(sema.gpa);
+
+            try sema.addFieldErrNote(block, union_ty, field_index, msg, "field '{s}' declared here", .{field_name});
+            try sema.addDeclaredHereNote(msg, union_ty);
+            break :msg msg;
+        };
+        return sema.failWithOwnedErrorMsg(msg);
+    }
+
     if (try sema.resolveDefinedValue(block, src, union_ptr)) |union_ptr_val| ct: {
         switch (union_obj.layout) {
             .Auto => if (!initializing) {
@@ -21734,6 +21800,10 @@ fn unionFieldPtr(
         const ok = try block.addBinOp(.cmp_eq, active_tag, wanted_tag);
         try sema.addSafetyCheck(block, ok, .inactive_union_field);
     }
+    if (field.ty.zigTypeTag() == .NoReturn) {
+        _ = try block.addNoOp(.unreach);
+        return Air.Inst.Ref.unreachable_value;
+    }
     return block.addStructFieldPtr(union_ptr, field_index, ptr_field_ty);
 }
 
@@ -21802,6 +21872,10 @@ fn unionFieldVal(
         const ok = try block.addBinOp(.cmp_eq, active_tag, wanted_tag);
         try sema.addSafetyCheck(block, ok, .inactive_union_field);
     }
+    if (field.ty.zigTypeTag() == .NoReturn) {
+        _ = try block.addNoOp(.unreach);
+        return Air.Inst.Ref.unreachable_value;
+    }
     return block.addStructFieldVal(union_byval, field_index, field.ty);
 }
 
@@ -25002,6 +25076,18 @@ fn coerceEnumToUnion(
         };
         const field = union_obj.fields.values()[field_index];
         const field_ty = try sema.resolveTypeFields(block, inst_src, field.ty);
+        if (field_ty.zigTypeTag() == .NoReturn) {
+            const msg = msg: {
+                const msg = try sema.errMsg(block, inst_src, "cannot initialize 'noreturn' field of union", .{});
+                errdefer msg.destroy(sema.gpa);
+
+                const field_name = union_obj.fields.keys()[field_index];
+                try sema.addFieldErrNote(block, union_ty, field_index, msg, "field '{s}' declared here", .{field_name});
+                try sema.addDeclaredHereNote(msg, union_ty);
+                break :msg msg;
+            };
+            return sema.failWithOwnedErrorMsg(msg);
+        }
         const opv = (try sema.typeHasOnePossibleValue(block, inst_src, field_ty)) orelse {
             const msg = msg: {
                 const field_name = union_obj.fields.keys()[field_index];
@@ -25037,13 +25123,37 @@ fn coerceEnumToUnion(
         return sema.failWithOwnedErrorMsg(msg);
     }
 
+    const union_obj = union_ty.cast(Type.Payload.Union).?.data;
+    {
+        var msg: ?*Module.ErrorMsg = null;
+        errdefer if (msg) |some| some.destroy(sema.gpa);
+
+        for (union_obj.fields.values()) |field, i| {
+            if (field.ty.zigTypeTag() == .NoReturn) {
+                const err_msg = msg orelse try sema.errMsg(
+                    block,
+                    inst_src,
+                    "runtime coercion from enum '{}' to union '{}' which has a 'noreturn' field",
+                    .{ tag_ty.fmt(sema.mod), union_ty.fmt(sema.mod) },
+                );
+                msg = err_msg;
+
+                try sema.addFieldErrNote(block, union_ty, i, err_msg, "'noreturn' field here", .{});
+            }
+        }
+        if (msg) |some| {
+            msg = null;
+            try sema.addDeclaredHereNote(some, union_ty);
+            return sema.failWithOwnedErrorMsg(some);
+        }
+    }
+
     // If the union has all fields 0 bits, the union value is just the enum value.
     if (union_ty.unionHasAllZeroBitFieldTypes()) {
         return block.addBitCast(union_ty, enum_tag);
     }
 
     const msg = msg: {
-        const union_obj = union_ty.cast(Type.Payload.Union).?.data;
         const msg = try sema.errMsg(
             block,
             inst_src,
@@ -25054,11 +25164,11 @@ fn coerceEnumToUnion(
 
         var it = union_obj.fields.iterator();
         var field_index: usize = 0;
-        while (it.next()) |field| {
+        while (it.next()) |field| : (field_index += 1) {
             const field_name = field.key_ptr.*;
             const field_ty = field.value_ptr.ty;
+            if (!field_ty.hasRuntimeBits()) continue;
             try sema.addFieldErrNote(block, union_ty, field_index, msg, "field '{s}' has type '{}'", .{ field_name, field_ty.fmt(sema.mod) });
-            field_index += 1;
         }
         try sema.addDeclaredHereNote(msg, union_ty);
         break :msg msg;
@@ -25361,6 +25471,7 @@ fn coerceTupleToStruct(
 
     // Populate default field values and report errors for missing fields.
     var root_msg: ?*Module.ErrorMsg = null;
+    errdefer if (root_msg) |msg| msg.destroy(sema.gpa);
 
     for (field_refs) |*field_ref, i| {
         if (field_ref.* != .none) continue;
@@ -25386,6 +25497,7 @@ fn coerceTupleToStruct(
     }
 
     if (root_msg) |msg| {
+        root_msg = null;
         try sema.addDeclaredHereNote(msg, struct_ty);
         return sema.failWithOwnedErrorMsg(msg);
     }
@@ -25455,6 +25567,7 @@ fn coerceTupleToTuple(
 
     // Populate default field values and report errors for missing fields.
     var root_msg: ?*Module.ErrorMsg = null;
+    errdefer if (root_msg) |msg| msg.destroy(sema.gpa);
 
     for (field_refs) |*field_ref, i| {
         if (field_ref.* != .none) continue;
@@ -25490,6 +25603,7 @@ fn coerceTupleToTuple(
     }
 
     if (root_msg) |msg| {
+        root_msg = null;
         try sema.addDeclaredHereNote(msg, tuple_ty);
         return sema.failWithOwnedErrorMsg(msg);
     }
@@ -27914,6 +28028,18 @@ fn semaStructFields(mod: *Module, struct_obj: *Module.Struct) CompileError!void
             };
             return sema.failWithOwnedErrorMsg(msg);
         }
+        if (field_ty.zigTypeTag() == .NoReturn) {
+            const msg = msg: {
+                const tree = try sema.getAstTree(&block_scope);
+                const field_src = enumFieldSrcLoc(decl, tree.*, 0, i);
+                const msg = try sema.errMsg(&block_scope, field_src, "struct fields cannot be 'noreturn'", .{});
+                errdefer msg.destroy(sema.gpa);
+
+                try sema.addDeclaredHereNote(msg, field_ty);
+                break :msg msg;
+            };
+            return sema.failWithOwnedErrorMsg(msg);
+        }
         if (struct_obj.layout == .Extern and !sema.validateExternType(field.ty, .other)) {
             const msg = msg: {
                 const tree = try sema.getAstTree(&block_scope);
@@ -28725,6 +28851,16 @@ fn enumFieldSrcLoc(
         .container_decl_arg_trailing,
         => tree.containerDeclArg(enum_node),
 
+        .tagged_union,
+        .tagged_union_trailing,
+        => tree.taggedUnion(enum_node),
+        .tagged_union_two,
+        .tagged_union_two_trailing,
+        => tree.taggedUnionTwo(&buffer, enum_node),
+        .tagged_union_enum_tag,
+        .tagged_union_enum_tag_trailing,
+        => tree.taggedUnionEnumTag(enum_node),
+
         // Container was constructed with `@Type`.
         else => return LazySrcLoc.nodeOffset(0),
     };
@@ -29375,7 +29511,9 @@ fn unionFieldAlignment(
     src: LazySrcLoc,
     field: Module.Union.Field,
 ) !u32 {
-    if (field.abi_align == 0) {
+    if (field.ty.zigTypeTag() == .NoReturn) {
+        return @as(u32, 0);
+    } else if (field.abi_align == 0) {
         return sema.typeAbiAlignment(block, src, field.ty);
     } else {
         return field.abi_align;
test/behavior/union.zig
@@ -1256,3 +1256,48 @@ test "return an extern union from C calling convention" {
     });
     try expect(u.d == 4.0);
 }
+
+test "noreturn field in union" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const U = union(enum) {
+        a: u32,
+        b: noreturn,
+        c: noreturn,
+    };
+    var a = U{ .a = 1 };
+    var count: u32 = 0;
+    if (a == .b) @compileError("bad");
+    switch (a) {
+        .a => count += 1,
+        .b => |val| {
+            _ = val;
+            @compileError("bad");
+        },
+        .c => @compileError("bad"),
+    }
+    switch (a) {
+        .a => count += 1,
+        .b, .c => @compileError("bad"),
+    }
+    switch (a) {
+        .a, .b, .c => {
+            count += 1;
+            try expect(a == .a);
+        },
+    }
+    switch (a) {
+        .a => count += 1,
+        else => @compileError("bad"),
+    }
+    switch (a) {
+        else => {
+            count += 1;
+            try expect(a == .a);
+        },
+    }
+    try expect(count == 5);
+}
test/cases/compile_errors/noreturn_struct_field.zig
@@ -0,0 +1,12 @@
+const S = struct {
+    s: noreturn,
+};
+comptime {
+    _ = @typeInfo(S);
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :2:5: error: struct fields cannot be 'noreturn'
test/cases/compile_errors/runtime_cast_to_union_which_has_non-void_fields.zig
@@ -18,6 +18,4 @@ fn foo(l: Letter) void {
 //
 // :11:20: error: runtime coercion from enum 'tmp.Letter' to union 'tmp.Value' which has non-void fields
 // :3:5: note: field 'A' has type 'i32'
-// :4:5: note: field 'B' has type 'void'
-// :5:5: note: field 'C' has type 'void'
 // :2:15: note: union declared here
test/cases/compile_errors/union_noreturn_field_initialized.zig
@@ -0,0 +1,43 @@
+pub export fn entry1() void {
+    const U = union(enum) {
+        a: u32,
+        b: noreturn,
+        fn foo(_: @This()) void {}
+        fn bar() noreturn {
+            unreachable;
+        }
+    };
+
+    var a = U{ .b = undefined };
+    _ = a;
+}
+pub export fn entry2() void {
+    const U = union(enum) {
+        a: noreturn,
+    };
+    var u: U = undefined;
+    u = .a;
+}
+pub export fn entry3() void {
+    const U = union(enum) {
+        a: noreturn,
+        b: void,
+    };
+    var e = @typeInfo(U).Union.tag_type.?.a;
+    var u: U = undefined;
+    u = e;
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :11:21: error: cannot initialize 'noreturn' field of union
+// :4:9: note: field 'b' declared here
+// :2:15: note: union declared here
+// :19:10: error: cannot initialize 'noreturn' field of union
+// :16:9: note: field 'a' declared here
+// :15:15: note: union declared here
+// :28:9: error: runtime coercion from enum '@typeInfo(tmp.entry3.U).Union.tag_type.?' to union 'tmp.entry3.U' which has a 'noreturn' field
+// :23:9: note: 'noreturn' field here
+// :22:15: note: union declared here