Commit 1705a21f80

Veikka Tuominen <git@vexu.eu>
2022-07-19 00:25:10
Sema: more union and enum tag type validation
1 parent 8feb398
src/AstGen.zig
@@ -4299,7 +4299,7 @@ fn unionDeclInner(
     members: []const Ast.Node.Index,
     layout: std.builtin.Type.ContainerLayout,
     arg_node: Ast.Node.Index,
-    have_auto_enum: bool,
+    auto_enum_tok: ?Ast.TokenIndex,
 ) InnerError!Zir.Inst.Ref {
     const decl_inst = try gz.reserveInstructionIndex();
 
@@ -4333,6 +4333,15 @@ fn unionDeclInner(
     const decl_count = try astgen.scanDecls(&namespace, members);
     const field_count = @intCast(u32, members.len - decl_count);
 
+    if (layout != .Auto and (auto_enum_tok != null or arg_node != 0)) {
+        const layout_str = if (layout == .Extern) "extern" else "packed";
+        if (arg_node != 0) {
+            return astgen.failNode(arg_node, "{s} union does not support enum tag type", .{layout_str});
+        } else {
+            return astgen.failTok(auto_enum_tok.?, "{s} union does not support enum tag type", .{layout_str});
+        }
+    }
+
     const arg_inst: Zir.Inst.Ref = if (arg_node != 0)
         try typeExpr(&block_scope, &namespace.base, arg_node)
     else
@@ -4367,7 +4376,7 @@ fn unionDeclInner(
         if (have_type) {
             const field_type = try typeExpr(&block_scope, &namespace.base, member.ast.type_expr);
             wip_members.appendToField(@enumToInt(field_type));
-        } else if (arg_inst == .none and !have_auto_enum) {
+        } else if (arg_inst == .none and auto_enum_tok == null) {
             return astgen.failNode(member_node, "union field missing type", .{});
         }
         if (have_align) {
@@ -4389,7 +4398,7 @@ fn unionDeclInner(
                     },
                 );
             }
-            if (!have_auto_enum) {
+            if (auto_enum_tok == null) {
                 return astgen.failNodeNotes(
                     node,
                     "explicitly valued tagged union requires inferred enum tag type",
@@ -4425,7 +4434,7 @@ fn unionDeclInner(
         .body_len = body_len,
         .fields_len = field_count,
         .decls_len = decl_count,
-        .auto_enum_tag = have_auto_enum,
+        .auto_enum_tag = auto_enum_tok != null,
     });
 
     wip_members.finishBits(bits_per_field);
@@ -4481,9 +4490,7 @@ fn containerDecl(
                 else => unreachable,
             } else std.builtin.Type.ContainerLayout.Auto;
 
-            const have_auto_enum = container_decl.ast.enum_token != null;
-
-            const result = try unionDeclInner(gz, scope, node, container_decl.ast.members, layout, container_decl.ast.arg, have_auto_enum);
+            const result = try unionDeclInner(gz, scope, node, container_decl.ast.members, layout, container_decl.ast.arg, container_decl.ast.enum_token);
             return rvalue(gz, rl, result, node);
         },
         .keyword_enum => {
src/Module.zig
@@ -2626,6 +2626,29 @@ pub const SrcLoc = struct {
                 };
                 return nodeToSpan(tree, full.ast.bit_range_end);
             },
+            .node_offset_container_tag => |node_off| {
+                const tree = try src_loc.file_scope.getTree(gpa);
+                const node_tags = tree.nodes.items(.tag);
+                const parent_node = src_loc.declRelativeToNodeIndex(node_off);
+
+                switch (node_tags[parent_node]) {
+                    .container_decl_arg, .container_decl_arg_trailing => {
+                        const full = tree.containerDeclArg(parent_node);
+                        return nodeToSpan(tree, full.ast.arg);
+                    },
+                    .tagged_union_enum_tag, .tagged_union_enum_tag_trailing => {
+                        const full = tree.taggedUnionEnumTag(parent_node);
+
+                        return tokensToSpan(
+                            tree,
+                            tree.firstToken(full.ast.arg) - 2,
+                            tree.lastToken(full.ast.arg) + 1,
+                            tree.nodes.items(.main_token)[full.ast.arg],
+                        );
+                    },
+                    else => unreachable,
+                }
+            },
         }
     }
 
@@ -2935,6 +2958,9 @@ pub const LazySrcLoc = union(enum) {
     /// The source location points to the host size of a pointer.
     /// The Decl is determined contextually.
     node_offset_ptr_hostsize: i32,
+    /// The source location points to the tag type of an union or an enum.
+    /// The Decl is determined contextually.
+    node_offset_container_tag: i32,
 
     pub const nodeOffset = if (TracedOffset.want_tracing) nodeOffsetDebug else nodeOffsetRelease;
 
@@ -3008,6 +3034,7 @@ pub const LazySrcLoc = union(enum) {
             .node_offset_ptr_addrspace,
             .node_offset_ptr_bitoffset,
             .node_offset_ptr_hostsize,
+            .node_offset_container_tag,
             => .{
                 .file_scope = decl.getFileScope(),
                 .parent_decl_node = decl.src_node,
@@ -4711,7 +4738,7 @@ pub fn scanNamespace(
     extra_start: usize,
     decls_len: u32,
     parent_decl: *Decl,
-) SemaError!usize {
+) Allocator.Error!usize {
     const tracy = trace(@src());
     defer tracy.end();
 
@@ -4758,7 +4785,7 @@ const ScanDeclIter = struct {
     unnamed_test_index: usize = 0,
 };
 
-fn scanDecl(iter: *ScanDeclIter, decl_sub_index: usize, flags: u4) SemaError!void {
+fn scanDecl(iter: *ScanDeclIter, decl_sub_index: usize, flags: u4) Allocator.Error!void {
     const tracy = trace(@src());
     defer tracy.end();
 
src/Sema.zig
@@ -2344,6 +2344,7 @@ fn zirEnumDecl(
         extra_index += 1;
         break :blk LazySrcLoc.nodeOffset(node_offset);
     } else sema.src;
+    const tag_ty_src: LazySrcLoc = .{ .node_offset_container_tag = src.node_offset.x };
 
     const tag_type_ref = if (small.has_tag_type) blk: {
         const tag_type_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
@@ -2369,8 +2370,10 @@ fn zirEnumDecl(
         break :blk decls_len;
     } else 0;
 
+    var done = false;
+
     var new_decl_arena = std.heap.ArenaAllocator.init(gpa);
-    errdefer new_decl_arena.deinit();
+    errdefer if (!done) new_decl_arena.deinit();
     const new_decl_arena_allocator = new_decl_arena.allocator();
 
     const enum_obj = try new_decl_arena_allocator.create(Module.EnumFull);
@@ -2387,7 +2390,7 @@ fn zirEnumDecl(
     }, small.name_strategy, "enum", inst);
     const new_decl = mod.declPtr(new_decl_index);
     new_decl.owns_tv = true;
-    errdefer mod.abortAnonDecl(new_decl_index);
+    errdefer if (!done) mod.abortAnonDecl(new_decl_index);
 
     enum_obj.* = .{
         .owner_decl = new_decl_index,
@@ -2406,19 +2409,28 @@ fn zirEnumDecl(
         &enum_obj.namespace, new_decl, new_decl.name,
     });
 
+    try new_decl.finalizeNewArena(&new_decl_arena);
+    const decl_val = try sema.analyzeDeclVal(block, src, new_decl_index);
+    done = true;
+
+    var decl_arena = new_decl.value_arena.?.promote(gpa);
+    defer new_decl.value_arena.?.* = decl_arena.state;
+    const decl_arena_allocator = decl_arena.allocator();
+
     extra_index = try mod.scanNamespace(&enum_obj.namespace, extra_index, decls_len, new_decl);
 
     const body = sema.code.extra[extra_index..][0..body_len];
     if (fields_len == 0) {
         assert(body.len == 0);
         if (tag_type_ref != .none) {
-            // TODO better source location
-            const ty = try sema.resolveType(block, src, tag_type_ref);
+            const ty = try sema.resolveType(block, tag_ty_src, tag_type_ref);
+            if (ty.zigTypeTag() != .Int and ty.zigTypeTag() != .ComptimeInt) {
+                return sema.fail(block, tag_ty_src, "expected integer tag type, found '{}'", .{ty.fmt(sema.mod)});
+            }
             enum_obj.tag_ty = try ty.copy(new_decl_arena_allocator);
             enum_obj.tag_ty_inferred = false;
         }
-        try new_decl.finalizeNewArena(&new_decl_arena);
-        return sema.analyzeDeclVal(block, src, new_decl_index);
+        return decl_val;
     }
     extra_index += body.len;
 
@@ -2471,13 +2483,15 @@ fn zirEnumDecl(
         try wip_captures.finalize();
 
         if (tag_type_ref != .none) {
-            // TODO better source location
-            const ty = try sema.resolveType(block, src, tag_type_ref);
-            enum_obj.tag_ty = try ty.copy(new_decl_arena_allocator);
+            const ty = try sema.resolveType(block, tag_ty_src, tag_type_ref);
+            if (ty.zigTypeTag() != .Int and ty.zigTypeTag() != .ComptimeInt) {
+                return sema.fail(block, tag_ty_src, "expected integer tag type, found '{}'", .{ty.fmt(sema.mod)});
+            }
+            enum_obj.tag_ty = try ty.copy(decl_arena_allocator);
             enum_obj.tag_ty_inferred = false;
         } else {
             const bits = std.math.log2_int_ceil(usize, fields_len);
-            enum_obj.tag_ty = try Type.Tag.int_unsigned.create(new_decl_arena_allocator, bits);
+            enum_obj.tag_ty = try Type.Tag.int_unsigned.create(decl_arena_allocator, bits);
             enum_obj.tag_ty_inferred = true;
         }
     }
@@ -2488,12 +2502,12 @@ fn zirEnumDecl(
         }
     }
 
-    try enum_obj.fields.ensureTotalCapacity(new_decl_arena_allocator, fields_len);
+    try enum_obj.fields.ensureTotalCapacity(decl_arena_allocator, fields_len);
     const any_values = for (sema.code.extra[body_end..][0..bit_bags_count]) |bag| {
         if (bag != 0) break true;
     } else false;
     if (any_values) {
-        try enum_obj.values.ensureTotalCapacityContext(new_decl_arena_allocator, fields_len, .{
+        try enum_obj.values.ensureTotalCapacityContext(decl_arena_allocator, fields_len, .{
             .ty = enum_obj.tag_ty,
             .mod = mod,
         });
@@ -2518,7 +2532,7 @@ fn zirEnumDecl(
         extra_index += 1;
 
         // This string needs to outlive the ZIR code.
-        const field_name = try new_decl_arena_allocator.dupe(u8, field_name_zir);
+        const field_name = try decl_arena_allocator.dupe(u8, field_name_zir);
 
         const gop = enum_obj.fields.getOrPutAssumeCapacity(field_name);
         if (gop.found_existing) {
@@ -2542,7 +2556,7 @@ fn zirEnumDecl(
             // But only resolve the source location if we need to emit a compile error.
             const tag_val = (try sema.resolveInstConst(block, src, tag_val_ref, "enum tag value must be comptime known")).val;
             last_tag_val = tag_val;
-            const copied_tag_val = try tag_val.copy(new_decl_arena_allocator);
+            const copied_tag_val = try tag_val.copy(decl_arena_allocator);
             enum_obj.values.putAssumeCapacityNoClobberContext(copied_tag_val, {}, .{
                 .ty = enum_obj.tag_ty,
                 .mod = mod,
@@ -2553,16 +2567,14 @@ fn zirEnumDecl(
             else
                 Value.zero;
             last_tag_val = tag_val;
-            const copied_tag_val = try tag_val.copy(new_decl_arena_allocator);
+            const copied_tag_val = try tag_val.copy(decl_arena_allocator);
             enum_obj.values.putAssumeCapacityNoClobberContext(copied_tag_val, {}, .{
                 .ty = enum_obj.tag_ty,
                 .mod = mod,
             });
         }
     }
-
-    try new_decl.finalizeNewArena(&new_decl_arena);
-    return sema.analyzeDeclVal(block, src, new_decl_index);
+    return decl_val;
 }
 
 fn zirUnionDecl(
@@ -8551,11 +8563,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                         if (seen_src != null) continue;
 
                         const field_name = operand_ty.enumFieldName(i);
-
-                        const field_src = src; // TODO better source location
-                        try sema.errNote(
+                        try sema.addFieldErrNote(
                             block,
-                            field_src,
+                            operand_ty,
+                            i,
                             msg,
                             "unhandled enumeration value: '{s}'",
                             .{field_name},
@@ -10587,7 +10598,7 @@ fn zirOverflowArithmetic(
     try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src);
     const dest_ty = lhs_ty;
     if (dest_ty.scalarType().zigTypeTag() != .Int) {
-        return sema.fail(block, src, "expected vector of integers or integer type, found '{}'", .{dest_ty.fmt(mod)});
+        return sema.fail(block, src, "expected vector of integers or integer tag type, found '{}'", .{dest_ty.fmt(mod)});
     }
 
     const maybe_lhs_val = try sema.resolveMaybeUndefVal(block, lhs_src, lhs);
@@ -25175,7 +25186,7 @@ fn resolveTypeFieldsUnion(
     }
 
     union_obj.status = .field_types_wip;
-    try semaUnionFields(block, sema.mod, union_obj);
+    try semaUnionFields(sema.mod, union_obj);
     union_obj.status = .have_field_types;
 }
 
@@ -25462,7 +25473,7 @@ fn semaStructFields(mod: *Module, struct_obj: *Module.Struct) CompileError!void
     struct_obj.have_field_inits = true;
 }
 
-fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) CompileError!void {
+fn semaUnionFields(mod: *Module, union_obj: *Module.Union) CompileError!void {
     const tracy = trace(@src());
     defer tracy.end();
 
@@ -25567,10 +25578,14 @@ fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) Compil
     var enum_value_map: ?*Module.EnumNumbered.ValueMap = null;
     var tag_ty_field_names: ?Module.EnumFull.NameMap = null;
     if (tag_type_ref != .none) {
-        const provided_ty = try sema.resolveType(&block_scope, src, tag_type_ref);
+        const tag_ty_src: LazySrcLoc = .{ .node_offset_container_tag = src.node_offset.x };
+        const provided_ty = try sema.resolveType(&block_scope, tag_ty_src, tag_type_ref);
         if (small.auto_enum_tag) {
             // The provided type is an integer type and we must construct the enum tag type here.
             int_tag_ty = provided_ty;
+            if (int_tag_ty.zigTypeTag() != .Int and int_tag_ty.zigTypeTag() != .ComptimeInt) {
+                return sema.fail(&block_scope, tag_ty_src, "expected integer tag type, found '{}'", .{int_tag_ty.fmt(sema.mod)});
+            }
             union_obj.tag_ty = try sema.generateUnionTagTypeNumbered(&block_scope, fields_len, provided_ty, union_obj);
             const enum_obj = union_obj.tag_ty.castTag(.enum_numbered).?.data;
             enum_field_names = &enum_obj.fields;
@@ -25579,8 +25594,7 @@ fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) Compil
             // The provided type is the enum tag type.
             union_obj.tag_ty = try provided_ty.copy(decl_arena_allocator);
             if (union_obj.tag_ty.zigTypeTag() != .Enum) {
-                const tag_ty_src = src; // TODO better source location
-                return sema.fail(block, tag_ty_src, "expected enum tag type, found '{}'", .{union_obj.tag_ty.fmt(sema.mod)});
+                return sema.fail(&block_scope, tag_ty_src, "expected enum tag type, found '{}'", .{union_obj.tag_ty.fmt(sema.mod)});
             }
             // The fields of the union must match the enum exactly.
             // Store a copy of the enum field names so we can check for
@@ -25658,7 +25672,7 @@ fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) Compil
                 });
             } else {
                 const val = if (last_tag_val) |val|
-                    try sema.intAdd(block, src, val, Value.one, int_tag_ty)
+                    try sema.intAdd(&block_scope, src, val, Value.one, int_tag_ty)
                 else
                     Value.zero;
                 last_tag_val = val;
@@ -25712,12 +25726,14 @@ fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) Compil
             const enum_has_field = names.orderedRemove(field_name);
             if (!enum_has_field) {
                 const msg = msg: {
-                    const msg = try sema.errMsg(block, src, "enum '{}' has no field named '{s}'", .{ union_obj.tag_ty.fmt(sema.mod), field_name });
+                    const tree = try sema.getAstTree(&block_scope);
+                    const field_src = enumFieldSrcLoc(decl, tree.*, union_obj.node_offset, field_i);
+                    const msg = try sema.errMsg(&block_scope, field_src, "enum '{}' has no field named '{s}'", .{ union_obj.tag_ty.fmt(sema.mod), field_name });
                     errdefer msg.destroy(sema.gpa);
                     try sema.addDeclaredHereNote(msg, union_obj.tag_ty);
                     break :msg msg;
                 };
-                return sema.failWithOwnedErrorMsg(block, msg);
+                return sema.failWithOwnedErrorMsg(&block_scope, msg);
             }
         }
 
@@ -25739,18 +25755,18 @@ fn semaUnionFields(block: *Block, mod: *Module, union_obj: *Module.Union) Compil
     if (tag_ty_field_names) |names| {
         if (names.count() > 0) {
             const msg = msg: {
-                const msg = try sema.errMsg(block, src, "enum field(s) missing in union", .{});
+                const msg = try sema.errMsg(&block_scope, src, "enum field(s) missing in union", .{});
                 errdefer msg.destroy(sema.gpa);
 
                 const enum_ty = union_obj.tag_ty;
                 for (names.keys()) |field_name| {
                     const field_index = enum_ty.enumFieldIndex(field_name).?;
-                    try sema.addFieldErrNote(block, enum_ty, field_index, msg, "field '{s}' missing, declared here", .{field_name});
+                    try sema.addFieldErrNote(&block_scope, enum_ty, field_index, msg, "field '{s}' missing, declared here", .{field_name});
                 }
                 try sema.addDeclaredHereNote(msg, union_obj.tag_ty);
                 break :msg msg;
             };
-            return sema.failWithOwnedErrorMsg(block, msg);
+            return sema.failWithOwnedErrorMsg(&block_scope, msg);
         }
     }
 }
test/cases/compile_errors/stage2/union_enum_field_missing.zig
@@ -16,6 +16,6 @@ export fn entry() usize {
 // error
 // target=native
 //
-// :7:1: error: enum field(s) missing in union
+// :7:11: error: enum field(s) missing in union
 // :4:5: note: field 'c' missing, declared here
 // :1:11: note: enum declared here
test/cases/compile_errors/stage2/union_extra_field.zig
@@ -16,5 +16,5 @@ export fn entry() usize {
 // error
 // target=native
 //
-// :6:1: error: enum 'tmp.E' has no field named 'd'
+// :10:5: error: enum 'tmp.E' has no field named 'd'
 // :1:11: note: enum declared here
test/cases/compile_errors/stage1/obj/extern_union_given_enum_tag_type.zig โ†’ test/cases/compile_errors/extern_union_given_enum_tag_type.zig
@@ -14,7 +14,7 @@ export fn entry() void {
 }
 
 // error
-// backend=stage1
+// backend=stage2
 // target=native
 //
-// tmp.zig:6:30: error: extern union does not support enum tag type
+// :6:30: error: extern union does not support enum tag type
test/cases/compile_errors/stage1/obj/non-enum_tag_type_passed_to_union.zig โ†’ test/cases/compile_errors/non-enum_tag_type_passed_to_union.zig
@@ -7,7 +7,7 @@ export fn entry() void {
 }
 
 // error
-// backend=stage1
+// backend=stage2
 // target=native
 //
-// tmp.zig:1:19: error: expected enum tag type, found 'u32'
+// :1:19: error: expected enum tag type, found 'u32'
test/cases/compile_errors/stage1/obj/non-integer_tag_type_to_automatic_union_enum.zig โ†’ test/cases/compile_errors/non-integer_tag_type_to_automatic_union_enum.zig
@@ -7,7 +7,7 @@ export fn entry() void {
 }
 
 // error
-// backend=stage1
+// backend=stage2
 // target=native
 //
-// tmp.zig:1:24: error: expected integer tag type, found 'f32'
+// :1:24: error: expected integer tag type, found 'f32'
test/cases/compile_errors/non-integer_tag_type_to_enum.zig
@@ -0,0 +1,13 @@
+const Foo = enum(f32) {
+    A,
+};
+export fn entry() void {
+    var f: Foo = undefined;
+    _ = f;
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :1:18: error: expected integer tag type, found 'f32'
test/cases/compile_errors/switch_expression-missing_enumeration_prong.zig
@@ -19,5 +19,5 @@ export fn entry() usize { return @sizeOf(@TypeOf(&f)); }
 // target=native
 //
 // :8:5: error: switch must handle all possibilities
-// :8:5: note: unhandled enumeration value: 'Four'
+// :5:5: note: unhandled enumeration value: 'Four'
 // :1:16: note: enum 'tmp.Number' declared here
test/cases/compile_errors/switch_on_enum_with_1_field_with_no_prongs.zig
@@ -10,5 +10,5 @@ export fn entry() void {
 // target=native
 //
 // :5:5: error: switch must handle all possibilities
-// :5:5: note: unhandled enumeration value: 'M'
+// :1:20: note: unhandled enumeration value: 'M'
 // :1:13: note: enum 'tmp.Foo' declared here
test/cases/compile_errors/switching_with_non-exhaustive_enums.zig
@@ -35,7 +35,7 @@ pub export fn entry3() void {
 // target=native
 //
 // :12:5: error: switch must handle all possibilities
-// :12:5: note: unhandled enumeration value: 'b'
+// :3:5: note: unhandled enumeration value: 'b'
 // :1:11: note: enum 'tmp.E' declared here
 // :19:5: error: switch on non-exhaustive enum must include 'else' or '_' prong
 // :26:5: error: '_' prong only allowed when switching on non-exhaustive enums
test/cases/compile_errors/union_with_specified_enum_omits_field.zig
@@ -15,6 +15,6 @@ export fn entry() usize {
 // backend=stage2
 // target=native
 //
-// :6:1: error: enum field(s) missing in union
+// :6:17: error: enum field(s) missing in union
 // :4:5: note: field 'C' missing, declared here
 // :1:16: note: enum declared here
test/stage2/cbe.zig
@@ -772,7 +772,7 @@ pub fn addCases(ctx: *TestContext) !void {
             \\}
         , &.{
             ":4:5: error: switch must handle all possibilities",
-            ":4:5: note: unhandled enumeration value: 'b'",
+            ":1:21: note: unhandled enumeration value: 'b'",
             ":1:11: note: enum 'tmp.E' declared here",
         });