Commit a1ac2b95bb

Andrew Kelley <andrew@ziglang.org>
2021-04-21 06:48:18
AstGen: implement union decls
1 parent 971f3d9
Changed files (3)
src/AstGen.zig
@@ -1703,6 +1703,8 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: ast.Node.Index) Inner
             .struct_decl_packed,
             .struct_decl_extern,
             .union_decl,
+            .union_decl_packed,
+            .union_decl_extern,
             .enum_decl,
             .enum_decl_nonexhaustive,
             .opaque_decl,
@@ -2897,7 +2899,7 @@ fn structDeclInner(
         if (member.comptime_token) |comptime_token| {
             return astgen.failTok(comptime_token, "TODO implement comptime struct fields", .{});
         }
-        try fields_data.ensureCapacity(gpa, fields_data.items.len + 4);
+        try fields_data.ensureUnusedCapacity(gpa, 4);
 
         const field_name = try gz.identAsString(member.ast.name_token);
         fields_data.appendAssumeCapacity(field_name);
@@ -2969,6 +2971,229 @@ fn structDeclInner(
     return gz.indexToRef(decl_inst);
 }
 
+fn unionDeclInner( // Lowers one union container declaration into a single ZIR instruction whose payload is a `Zir.Inst.UnionDecl` plus trailing data.
+    gz: *GenZir, // receives the emitted declaration instruction
+    scope: *Scope, // parent of the comptime block scope; also handed to the decl helpers
+    node: ast.Node.Index, // the container node: used for decl_node_index, addBlock and error reporting
+    members: []const ast.Node.Index, // container members; fields are encoded here, everything else is forwarded to decl helpers
+    tag: Zir.Inst.Tag, // instruction tag passed to addBlock
+    arg_inst: Zir.Inst.Ref, // stored as `UnionDecl.tag_type`; when `.none`, explicit field tag values are an error
+    have_auto_enum: bool, // recorded in the highest bit of every field's 4-bit group
+) InnerError!Zir.Inst.Ref { // returns `gz.indexToRef` of the emitted declaration instruction
+    const astgen = gz.astgen;
+    const gpa = astgen.gpa;
+    const tree = &astgen.file.tree;
+    const node_tags = tree.nodes.items(.tag);
+    const node_datas = tree.nodes.items(.data);
+
+    // The union_decl instruction introduces a scope in which the decls of the union
+    // are in scope, so that field types, alignments, and tag value expressions
+    // can refer to decls within the union itself.
+    var block_scope: GenZir = .{
+        .parent = scope,
+        .decl_node_index = node,
+        .astgen = astgen,
+        .force_comptime = true,
+        .ref_start_index = gz.ref_start_index,
+    };
+    defer block_scope.instructions.deinit(gpa);
+
+    var wip_decls: WipDecls = .{};
+    defer wip_decls.deinit(gpa);
+
+    // We don't know which members are fields until we iterate, so cannot do
+    // an accurate ensureCapacity yet.
+    var fields_data = ArrayListUnmanaged(u32){};
+    defer fields_data.deinit(gpa);
+
+    const bits_per_field = 4; // has_type, has_align, has_value, have_auto_enum
+    const fields_per_u32 = 32 / bits_per_field;
+    // Only filled when there are more than fields_per_u32 fields; the last (possibly partial) bag stays in cur_bit_bag.
+    var bit_bag = ArrayListUnmanaged(u32){};
+    defer bit_bag.deinit(gpa);
+
+    var cur_bit_bag: u32 = 0;
+    var field_index: usize = 0; // counts fields only; non-field members `continue` before it is incremented
+    for (members) |member_node| {
+        const member = switch (node_tags[member_node]) {
+            .container_field_init => tree.containerFieldInit(member_node),
+            .container_field_align => tree.containerFieldAlign(member_node),
+            .container_field => tree.containerField(member_node),
+
+            .fn_decl => { // function with a body: lhs is the prototype node, rhs is the body node
+                const fn_proto = node_datas[member_node].lhs;
+                const body = node_datas[member_node].rhs;
+                switch (node_tags[fn_proto]) {
+                    .fn_proto_simple => {
+                        var params: [1]ast.Node.Index = undefined;
+                        try astgen.fnDecl(gz, &wip_decls, body, tree.fnProtoSimple(&params, fn_proto));
+                        continue;
+                    },
+                    .fn_proto_multi => {
+                        try astgen.fnDecl(gz, &wip_decls, body, tree.fnProtoMulti(fn_proto));
+                        continue;
+                    },
+                    .fn_proto_one => {
+                        var params: [1]ast.Node.Index = undefined;
+                        try astgen.fnDecl(gz, &wip_decls, body, tree.fnProtoOne(&params, fn_proto));
+                        continue;
+                    },
+                    .fn_proto => {
+                        try astgen.fnDecl(gz, &wip_decls, body, tree.fnProto(fn_proto));
+                        continue;
+                    },
+                    else => unreachable,
+                }
+            },
+            .fn_proto_simple => { // bare prototype member: 0 is passed as the body node
+                var params: [1]ast.Node.Index = undefined;
+                try astgen.fnDecl(gz, &wip_decls, 0, tree.fnProtoSimple(&params, member_node));
+                continue;
+            },
+            .fn_proto_multi => {
+                try astgen.fnDecl(gz, &wip_decls, 0, tree.fnProtoMulti(member_node));
+                continue;
+            },
+            .fn_proto_one => {
+                var params: [1]ast.Node.Index = undefined;
+                try astgen.fnDecl(gz, &wip_decls, 0, tree.fnProtoOne(&params, member_node));
+                continue;
+            },
+            .fn_proto => {
+                try astgen.fnDecl(gz, &wip_decls, 0, tree.fnProto(member_node));
+                continue;
+            },
+
+            .global_var_decl => { // all four var-decl node shapes go through globalVarDecl
+                try astgen.globalVarDecl(gz, scope, &wip_decls, member_node, tree.globalVarDecl(member_node));
+                continue;
+            },
+            .local_var_decl => {
+                try astgen.globalVarDecl(gz, scope, &wip_decls, member_node, tree.localVarDecl(member_node));
+                continue;
+            },
+            .simple_var_decl => {
+                try astgen.globalVarDecl(gz, scope, &wip_decls, member_node, tree.simpleVarDecl(member_node));
+                continue;
+            },
+            .aligned_var_decl => {
+                try astgen.globalVarDecl(gz, scope, &wip_decls, member_node, tree.alignedVarDecl(member_node));
+                continue;
+            },
+
+            .@"comptime" => {
+                try astgen.comptimeDecl(gz, scope, member_node);
+                continue;
+            },
+            .@"usingnamespace" => {
+                try astgen.usingnamespaceDecl(gz, scope, member_node);
+                continue;
+            },
+            .test_decl => {
+                try astgen.testDecl(gz, scope, member_node);
+                continue;
+            },
+            else => unreachable,
+        };
+        if (field_index % fields_per_u32 == 0 and field_index != 0) { // previous bag is full: flush it and start a new one
+            try bit_bag.append(gpa, cur_bit_bag);
+            cur_bit_bag = 0;
+        }
+        if (member.comptime_token) |comptime_token| {
+            return astgen.failTok(comptime_token, "union fields cannot be marked comptime", .{});
+        }
+        try fields_data.ensureUnusedCapacity(gpa, 4); // name + up to three optional refs (type, align, tag value)
+
+        const field_name = try gz.identAsString(member.ast.name_token);
+        fields_data.appendAssumeCapacity(field_name);
+
+        const have_type = member.ast.type_expr != 0;
+        const have_align = member.ast.align_expr != 0;
+        const have_value = member.ast.value_expr != 0;
+        cur_bit_bag = (cur_bit_bag >> bits_per_field) | // older fields shift down; the new field's bits enter at the top
+            (@as(u32, @boolToInt(have_type)) << 28) |
+            (@as(u32, @boolToInt(have_align)) << 29) |
+            (@as(u32, @boolToInt(have_value)) << 30) |
+            (@as(u32, @boolToInt(have_auto_enum)) << 31);
+
+        if (have_type) {
+            const field_type = try typeExpr(&block_scope, &block_scope.base, member.ast.type_expr);
+            fields_data.appendAssumeCapacity(@enumToInt(field_type));
+        }
+        if (have_align) {
+            const align_inst = try expr(&block_scope, &block_scope.base, .{ .ty = .u32_type }, member.ast.align_expr);
+            fields_data.appendAssumeCapacity(@enumToInt(align_inst));
+        }
+        if (have_value) {
+            if (arg_inst == .none) { // an explicit tag value needs a result type to coerce to
+                return astgen.failNodeNotes(
+                    node,
+                    "explicitly valued tagged union missing integer tag type",
+                    .{},
+                    &[_]u32{
+                        try astgen.errNoteNode(
+                            member.ast.value_expr,
+                            "tag value specified here",
+                            .{},
+                        ),
+                    },
+                );
+            }
+            const tag_value = try expr(&block_scope, &block_scope.base, .{ .ty = arg_inst }, member.ast.value_expr);
+            fields_data.appendAssumeCapacity(@enumToInt(tag_value));
+        }
+
+        field_index += 1;
+    }
+    if (field_index == 0) { // no member was a field
+        return astgen.failNode(node, "union declarations must have at least one tag", .{});
+    }
+    { // right-align a partially filled field bag so the first field's bits end up lowest
+        const empty_slot_count = fields_per_u32 - (field_index % fields_per_u32);
+        if (empty_slot_count < fields_per_u32) { // a completely full bag needs no shift
+            cur_bit_bag >>= @intCast(u5, empty_slot_count * bits_per_field);
+        }
+    }
+    { // same right-alignment for the decl bag: 2 bits per decl, 16 decls per u32
+        const empty_slot_count = 16 - (wip_decls.decl_index % 16);
+        if (empty_slot_count < 16) {
+            wip_decls.cur_bit_bag >>= @intCast(u5, empty_slot_count * 2);
+        }
+    }
+
+    const decl_inst = try gz.addBlock(tag, node);
+    try gz.instructions.append(gpa, decl_inst);
+    if (block_scope.instructions.items.len != 0) { // a non-empty body is terminated with a break_inline targeting decl_inst
+        _ = try block_scope.addBreak(.break_inline, decl_inst, .void_value);
+    }
+
+    try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.UnionDecl).Struct.fields.len + // reserve exactly what is appended below
+        bit_bag.items.len + 1 + fields_data.items.len +
+        block_scope.instructions.items.len +
+        wip_decls.bit_bag.items.len + @boolToInt(wip_decls.decl_index != 0) +
+        wip_decls.name_and_value.items.len);
+    const zir_datas = astgen.instructions.items(.data);
+    zir_datas[decl_inst].pl_node.payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.UnionDecl{
+        .tag_type = arg_inst,
+        .body_len = @intCast(u32, block_scope.instructions.items.len),
+        .fields_len = @intCast(u32, field_index),
+        .decls_len = @intCast(u32, wip_decls.decl_index),
+    });
+    astgen.extra.appendSliceAssumeCapacity(block_scope.instructions.items); // trailing data starts with the body instructions
+
+    astgen.extra.appendSliceAssumeCapacity(bit_bag.items); // Likely empty.
+    astgen.extra.appendAssumeCapacity(cur_bit_bag); // final field bag; always written (the `+ 1` reserved above)
+    astgen.extra.appendSliceAssumeCapacity(fields_data.items);
+
+    astgen.extra.appendSliceAssumeCapacity(wip_decls.bit_bag.items); // Likely empty.
+    if (wip_decls.decl_index != 0) { // the final decl bag is only written when there is at least one decl
+        astgen.extra.appendAssumeCapacity(wip_decls.cur_bit_bag);
+    }
+    astgen.extra.appendSliceAssumeCapacity(wip_decls.name_and_value.items);
+
+    return gz.indexToRef(decl_inst);
+}
+
 fn containerDecl(
     gz: *GenZir,
     scope: *Scope,
@@ -3005,7 +3230,18 @@ fn containerDecl(
             return rvalue(gz, scope, rl, result, node);
         },
         .keyword_union => {
-            return astgen.failTok(container_decl.ast.main_token, "TODO AstGen for union decl", .{});
+            const tag = if (container_decl.layout_token) |t| switch (token_tags[t]) {
+                .keyword_packed => Zir.Inst.Tag.union_decl_packed,
+                .keyword_extern => Zir.Inst.Tag.union_decl_extern,
+                else => unreachable,
+            } else Zir.Inst.Tag.union_decl;
+
+            // See `Zir.Inst.UnionDecl` doc comments for why this is stored along
+            // with fields instead of separately.
+            const have_auto_enum = container_decl.ast.enum_token != null;
+
+            const result = try unionDeclInner(gz, scope, node, container_decl.ast.members, tag, arg_inst, have_auto_enum);
+            return rvalue(gz, scope, rl, result, node);
         },
         .keyword_enum => {
             if (container_decl.layout_token) |t| {
@@ -3224,6 +3460,20 @@ fn containerDecl(
                     (@as(u32, @boolToInt(have_value)) << 31);
 
                 if (have_value) {
+                    if (arg_inst == .none) {
+                        return astgen.failNodeNotes(
+                            node,
+                            "explicitly valued enum missing integer tag type",
+                            .{},
+                            &[_]u32{
+                                try astgen.errNoteNode(
+                                    member.ast.value_expr,
+                                    "tag value specified here",
+                                    .{},
+                                ),
+                            },
+                        );
+                    }
                     const tag_value_inst = try expr(&block_scope, &block_scope.base, .{ .ty = arg_inst }, member.ast.value_expr);
                     fields_data.appendAssumeCapacity(@enumToInt(tag_value_inst));
                 }
@@ -3232,12 +3482,15 @@ fn containerDecl(
             }
             {
                 const empty_slot_count = 32 - (field_index % 32);
-                cur_bit_bag >>= @intCast(u5, empty_slot_count);
+                if (empty_slot_count < 32) {
+                    cur_bit_bag >>= @intCast(u5, empty_slot_count);
+                }
             }
-
-            if (wip_decls.decl_index != 0) {
+            {
                 const empty_slot_count = 16 - (wip_decls.decl_index % 16);
-                wip_decls.cur_bit_bag >>= @intCast(u5, empty_slot_count * 2);
+                if (empty_slot_count < 16) {
+                    wip_decls.cur_bit_bag >>= @intCast(u5, empty_slot_count * 2);
+                }
             }
 
             const decl_inst = try gz.addBlock(tag, node);
@@ -4789,7 +5042,7 @@ fn switchExpr(
                     .prong_index = capture_index,
                 } },
             });
-            const capture_name = try astgen.identifierTokenString(payload_token);
+            const capture_name = try astgen.identifierTokenString(ident);
             capture_val_scope = .{
                 .parent = &case_scope.base,
                 .gen_zir = &case_scope,
src/Sema.zig
@@ -338,7 +338,9 @@ pub fn analyzeBody(
             .struct_decl_extern      => try sema.zirStructDecl(block, inst, .Extern),
             .enum_decl               => try sema.zirEnumDecl(block, inst, false),
             .enum_decl_nonexhaustive => try sema.zirEnumDecl(block, inst, true),
-            .union_decl              => try sema.zirUnionDecl(block, inst),
+            .union_decl              => try sema.zirUnionDecl(block, inst, .Auto),
+            .union_decl_packed       => try sema.zirUnionDecl(block, inst, .Packed),
+            .union_decl_extern       => try sema.zirUnionDecl(block, inst, .Extern),
             .opaque_decl             => try sema.zirOpaqueDecl(block, inst),
             .error_set_decl          => try sema.zirErrorSetDecl(block, inst),
 
@@ -980,7 +982,12 @@ fn zirEnumDecl(
     return sema.analyzeDeclVal(block, src, new_decl);
 }
 
-fn zirUnionDecl(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) InnerError!*Inst {
+fn zirUnionDecl(
+    sema: *Sema,
+    block: *Scope.Block,
+    inst: Zir.Inst.Index,
+    layout: std.builtin.TypeInfo.ContainerLayout,
+) InnerError!*Inst {
     const tracy = trace(@src());
     defer tracy.end();
 
src/Zir.zig
@@ -301,6 +301,10 @@ pub const Inst = struct {
         /// the field types and optional type tag expression.
         /// Uses the `pl_node` union field. Payload is `UnionDecl`.
         union_decl,
+        /// Same as `union_decl`, except has the `packed` layout.
+        union_decl_packed,
+        /// Same as `union_decl`, except has the `extern` layout.
+        union_decl_extern,
         /// An enum type definition. Contains references to ZIR instructions for
         /// the field value expressions and optional type tag expression.
         /// Uses the `pl_node` union field. Payload is `EnumDecl`.
@@ -988,6 +992,8 @@ pub const Inst = struct {
                 .struct_decl_packed,
                 .struct_decl_extern,
                 .union_decl,
+                .union_decl_packed,
+                .union_decl_extern,
                 .enum_decl,
                 .enum_decl_nonexhaustive,
                 .opaque_decl,
@@ -2022,19 +2028,37 @@ pub const Inst = struct {
     };
 
     /// Trailing:
-    /// 0. has_bits: u32 // for every 10 fields (+1)
-    ///    - first bit is special: set if and only if auto enum tag is enabled.
-    ///    - sets of 3 bits:
-    ///      0b00X: whether corresponding field has a type expression
-    ///      0b0X0: whether corresponding field has a align expression
-    ///      0bX00: whether corresponding field has a tag value expression
-    /// 1. field_name: u32 // for every field: null terminated string index
-    /// 2. opt_exprs // Ref for every field for which corresponding bit is set
-    ///    - interleaved. type if present, align if present, tag value if present.
+    /// 0. inst: Index // for every body_len
+    /// 1. has_bits: u32 // for every 8 fields
+    ///    - sets of 4 bits:
+    ///      0b000X: whether corresponding field has a type expression
+    ///      0b00X0: whether corresponding field has an align expression
+    ///      0b0X00: whether corresponding field has a tag value expression
+    ///      0bX000: unused(*)
+    ///    * the first unused bit (the unused bit of the first field) is used
+    ///      to indicate whether auto enum tag is enabled.
+    ///      0 = union(tag_type)
+    ///      1 = union(enum(tag_type))
+    /// 2. fields: { // for every fields_len
+    ///        field_name: u32, // null terminated string index
+    ///        field_type: Ref, // if corresponding bit is set
+    ///        align: Ref, // if corresponding bit is set
+    ///        tag_value: Ref, // if corresponding bit is set
+    ///    }
+    /// 3. decl_bits: u32 // for every 16 decls
+    ///    - sets of 2 bits:
+    ///      0b0X: whether corresponding decl is pub
+    ///      0bX0: whether corresponding decl is exported
+    /// 4. decl: { // for every decls_len
+    ///        name: u32, // null terminated string index
+    ///        value: Index,
+    ///    }
     pub const UnionDecl = struct {
         /// Can be `Ref.none`.
         tag_type: Ref,
+        body_len: u32, // number of trailing body instruction indexes (trailing item 0)
         fields_len: u32, // number of fields described by the trailing has_bits and fields data
+        decls_len: u32, // number of decls described by the trailing decl_bits and decl data
     };
 
     /// Trailing: field_name: u32 // for every field: null terminated string index
@@ -2339,7 +2363,6 @@ const Writer = struct {
             .slice_start,
             .slice_end,
             .slice_sentinel,
-            .union_decl,
             .struct_init,
             .struct_init_anon,
             .array_init,
@@ -2452,6 +2475,11 @@ const Writer = struct {
             .struct_decl_extern,
             => try self.writeStructDecl(stream, inst),
 
+            .union_decl,
+            .union_decl_packed,
+            .union_decl_extern,
+            => try self.writeUnionDecl(stream, inst),
+
             .enum_decl,
             .enum_decl_nonexhaustive,
             => try self.writeEnumDecl(stream, inst),
@@ -2884,6 +2912,105 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    fn writeUnionDecl(self: *Writer, stream: anytype, inst: Inst.Index) !void { // Writes the `UnionDecl` payload of `inst` to `stream`: tag type, body, fields, decls, autoenum flag, source.
+        const inst_data = self.code.instructions.items(.data)[inst].pl_node;
+        const extra = self.code.extraData(Inst.UnionDecl, inst_data.payload_index);
+        const body = self.code.extra[extra.end..][0..extra.data.body_len]; // body instructions directly follow the UnionDecl struct in `extra`
+        const fields_len = extra.data.fields_len;
+        const decls_len = extra.data.decls_len;
+        const tag_type_ref = extra.data.tag_type;
+
+        assert(fields_len != 0); // the field loop below must run at least once so first_has_auto_enum gets set
+        var first_has_auto_enum: ?bool = null; // taken from the first field's fourth bit
+
+        if (tag_type_ref != .none) {
+            try self.writeInstRef(stream, tag_type_ref);
+            try stream.writeAll(", ");
+        }
+
+        var extra_index: usize = undefined; // assigned below once the extent of the field bit bags is known
+
+        try stream.writeAll("{\n");
+        self.indent += 2;
+        try self.writeBody(stream, body);
+
+        try stream.writeByteNTimes(' ', self.indent - 2); // closing brace of the body sits at the outer indent
+        try stream.writeAll("}, {\n");
+
+        const bits_per_field = 4;
+        const fields_per_u32 = 32 / bits_per_field;
+        const bit_bags_count = std.math.divCeil(usize, fields_len, fields_per_u32) catch unreachable; // one u32 per 8 fields, rounded up
+        const body_end = extra.end + body.len;
+        extra_index = body_end + bit_bags_count; // per-field data starts right after the bit bags
+        var bit_bag_index: usize = body_end;
+        var cur_bit_bag: u32 = undefined; // loaded on the first iteration (field_i == 0)
+        var field_i: u32 = 0;
+        while (field_i < fields_len) : (field_i += 1) {
+            if (field_i % fields_per_u32 == 0) { // fetch the next bag every fields_per_u32 fields
+                cur_bit_bag = self.code.extra[bit_bag_index];
+                bit_bag_index += 1;
+            }
+            const has_type = @truncate(u1, cur_bit_bag) != 0; // consume this field's 4 bits, lowest first
+            cur_bit_bag >>= 1;
+            const has_align = @truncate(u1, cur_bit_bag) != 0;
+            cur_bit_bag >>= 1;
+            const has_value = @truncate(u1, cur_bit_bag) != 0;
+            cur_bit_bag >>= 1;
+            const has_auto_enum = @truncate(u1, cur_bit_bag) != 0;
+            cur_bit_bag >>= 1;
+
+            if (first_has_auto_enum == null) { // only the first field's bit is kept
+                first_has_auto_enum = has_auto_enum;
+            }
+
+            const field_name = self.code.nullTerminatedString(self.code.extra[extra_index]);
+            extra_index += 1;
+            try stream.writeByteNTimes(' ', self.indent);
+            try stream.print("{}", .{std.zig.fmtId(field_name)});
+
+            if (has_type) {
+                const field_type = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
+
+                try stream.writeAll(": ");
+                try self.writeInstRef(stream, field_type);
+            }
+            if (has_align) {
+                const align_ref = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
+
+                try stream.writeAll(" align(");
+                try self.writeInstRef(stream, align_ref);
+                try stream.writeAll(")");
+            }
+            if (has_value) { // third optional slot; per the UnionDecl trailing-data docs this is the field's tag value
+                const default_ref = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
+
+                try stream.writeAll(" = ");
+                try self.writeInstRef(stream, default_ref);
+            }
+            try stream.writeAll(",\n");
+        }
+
+        self.indent -= 2;
+        try stream.writeByteNTimes(' ', self.indent);
+        try stream.writeAll("}, {");
+        if (decls_len == 0) {
+            try stream.writeAll("}");
+        } else {
+            try stream.writeAll("\n");
+            self.indent += 2;
+            try self.writeDecls(stream, decls_len, extra_index); // extra_index is now just past the last field's data
+            self.indent -= 2;
+            try stream.writeByteNTimes(' ', self.indent);
+            try stream.writeAll("}");
+        }
+        try self.writeFlag(stream, ", autoenum", first_has_auto_enum.?); // non-null: the loop ran at least once (asserted above)
+        try stream.writeAll(") ");
+        try self.writeSrc(stream, inst_data.src());
+    }
+
     fn writeDecls(self: *Writer, stream: anytype, decls_len: u32, extra_start: usize) !void {
         const parent_decl_node = self.parent_decl_node;
         const bit_bags_count = std.math.divCeil(usize, decls_len, 16) catch unreachable;
@@ -2930,10 +3057,10 @@ const Writer = struct {
         const body = self.code.extra[extra.end..][0..extra.data.body_len];
         const fields_len = extra.data.fields_len;
         const decls_len = extra.data.decls_len;
-        const tag_ty_ref = extra.data.tag_type;
+        const tag_type_ref = extra.data.tag_type;
 
-        if (tag_ty_ref != .none) {
-            try self.writeInstRef(stream, tag_ty_ref);
+        if (tag_type_ref != .none) {
+            try self.writeInstRef(stream, tag_type_ref);
             try stream.writeAll(", ");
         }