Commit 0e50a0c1e5

Andrew Kelley <andrew@ziglang.org>
2021-04-13 03:40:47
stage2: implement non-trivial enums
1 parent bcfebb4
Changed files (3)
src/AstGen.zig
@@ -1932,7 +1932,7 @@ fn containerDecl(
     // ZIR for all the field types, alignments, and default value expressions.
 
     const arg_inst: zir.Inst.Ref = if (container_decl.ast.arg != 0)
-        try comptimeExpr(gz, scope, .none, container_decl.ast.arg)
+        try comptimeExpr(gz, scope, .{ .ty = .type_type }, container_decl.ast.arg)
     else
         .none;
 
@@ -2006,6 +2006,9 @@ fn containerDecl(
                     }
                     total_fields += 1;
                     if (member.ast.value_expr != 0) {
+                        if (arg_inst == .none) {
+                            return mod.failNode(scope, member.ast.value_expr, "value assigned to enum tag with inferred tag type", .{});
+                        }
                         values += 1;
                     }
                 }
@@ -2118,7 +2121,91 @@ fn containerDecl(
             // In this case we must generate ZIR code for the tag values, similar to
             // how structs are handled above. The new anonymous Decl will be created in
             // Sema, not AstGen.
-            return mod.failNode(scope, node, "TODO AstGen for enum decl with decls or explicitly provided field values", .{});
+            const tag: zir.Inst.Tag = if (counts.nonexhaustive_node == 0)
+                .enum_decl
+            else
+                .enum_decl_nonexhaustive;
+            if (counts.total_fields == 0) {
+                return gz.addPlNode(tag, node, zir.Inst.EnumDecl{
+                    .tag_type = arg_inst,
+                    .fields_len = 0,
+                    .body_len = 0,
+                });
+            }
+
+            // The enum_decl instruction introduces a scope in which the decls of the enum
+            // are in scope, so that tag values can refer to decls within the enum itself.
+            var block_scope: GenZir = .{
+                .parent = scope,
+                .astgen = astgen,
+                .force_comptime = true,
+            };
+            defer block_scope.instructions.deinit(gpa);
+
+            var fields_data = ArrayListUnmanaged(u32){};
+            defer fields_data.deinit(gpa);
+
+            try fields_data.ensureCapacity(gpa, counts.total_fields + counts.values);
+
+            // We only need this if there are greater than 32 fields.
+            var bit_bag = ArrayListUnmanaged(u32){};
+            defer bit_bag.deinit(gpa);
+
+            var cur_bit_bag: u32 = 0;
+            var field_index: usize = 0;
+            for (container_decl.ast.members) |member_node| {
+                if (member_node == counts.nonexhaustive_node)
+                    continue;
+                const member = switch (node_tags[member_node]) {
+                    .container_field_init => tree.containerFieldInit(member_node),
+                    .container_field_align => tree.containerFieldAlign(member_node),
+                    .container_field => tree.containerField(member_node),
+                    else => continue,
+                };
+                if (field_index % 32 == 0 and field_index != 0) {
+                    try bit_bag.append(gpa, cur_bit_bag);
+                    cur_bit_bag = 0;
+                }
+                assert(member.comptime_token == null);
+                assert(member.ast.type_expr == 0);
+                assert(member.ast.align_expr == 0);
+
+                const field_name = try gz.identAsString(member.ast.name_token);
+                fields_data.appendAssumeCapacity(field_name);
+
+                const have_value = member.ast.value_expr != 0;
+                cur_bit_bag = (cur_bit_bag >> 1) |
+                    (@as(u32, @boolToInt(have_value)) << 31);
+
+                if (have_value) {
+                    const tag_value_inst = try expr(&block_scope, &block_scope.base, .{ .ty = arg_inst }, member.ast.value_expr);
+                    fields_data.appendAssumeCapacity(@enumToInt(tag_value_inst));
+                }
+
+                field_index += 1;
+            }
+            const empty_slot_count = 32 - (field_index % 32);
+            cur_bit_bag >>= @intCast(u5, empty_slot_count);
+
+            const decl_inst = try gz.addBlock(tag, node);
+            try gz.instructions.append(gpa, decl_inst);
+            _ = try block_scope.addBreak(.break_inline, decl_inst, .void_value);
+
+            try astgen.extra.ensureCapacity(gpa, astgen.extra.items.len +
+                @typeInfo(zir.Inst.EnumDecl).Struct.fields.len +
+                bit_bag.items.len + 1 + fields_data.items.len +
+                block_scope.instructions.items.len);
+            const zir_datas = astgen.instructions.items(.data);
+            zir_datas[decl_inst].pl_node.payload_index = astgen.addExtraAssumeCapacity(zir.Inst.EnumDecl{
+                .tag_type = arg_inst,
+                .body_len = @intCast(u32, block_scope.instructions.items.len),
+                .fields_len = @intCast(u32, field_index),
+            });
+            astgen.extra.appendSliceAssumeCapacity(block_scope.instructions.items);
+            astgen.extra.appendSliceAssumeCapacity(bit_bag.items); // Likely empty.
+            astgen.extra.appendAssumeCapacity(cur_bit_bag);
+            astgen.extra.appendSliceAssumeCapacity(fields_data.items);
+            return rvalue(gz, scope, rl, astgen.indexToRef(decl_inst), node);
         },
         .keyword_opaque => {
             const result = try gz.addNode(.opaque_decl, node);
src/Sema.zig
@@ -542,7 +542,7 @@ fn zirStructDecl(
     const body = sema.code.extra[extra.end..][0..extra.data.body_len];
     const fields_len = extra.data.fields_len;
 
-    var new_decl_arena = std.heap.ArenaAllocator.init(sema.gpa);
+    var new_decl_arena = std.heap.ArenaAllocator.init(gpa);
 
     const struct_obj = try new_decl_arena.allocator.create(Module.Struct);
     const struct_ty = try Type.Tag.@"struct".create(&new_decl_arena.allocator, struct_obj);
@@ -602,7 +602,7 @@ fn zirStructDecl(
         // should be the struct itself. Thus we need a new Sema.
         var struct_sema: Sema = .{
             .mod = sema.mod,
-            .gpa = sema.mod.gpa,
+            .gpa = gpa,
             .arena = &new_decl_arena.allocator,
             .code = sema.code,
             .inst_map = sema.inst_map,
@@ -632,7 +632,7 @@ fn zirStructDecl(
     }
     const bit_bags_count = std.math.divCeil(usize, fields_len, 16) catch unreachable;
     const body_end = extra.end + body.len;
-    var field_index: usize = body_end + bit_bags_count;
+    var extra_index: usize = body_end + bit_bags_count;
     var bit_bag_index: usize = body_end;
     var cur_bit_bag: u32 = undefined;
     var field_i: u32 = 0;
@@ -646,10 +646,10 @@ fn zirStructDecl(
         const has_default = @truncate(u1, cur_bit_bag) != 0;
         cur_bit_bag >>= 1;
 
-        const field_name_zir = sema.code.nullTerminatedString(sema.code.extra[field_index]);
-        field_index += 1;
-        const field_type_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[field_index]);
-        field_index += 1;
+        const field_name_zir = sema.code.nullTerminatedString(sema.code.extra[extra_index]);
+        extra_index += 1;
+        const field_type_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
+        extra_index += 1;
 
         // This string needs to outlive the ZIR code.
         const field_name = try new_decl_arena.allocator.dupe(u8, field_name_zir);
@@ -667,16 +667,16 @@ fn zirStructDecl(
         };
 
         if (has_align) {
-            const align_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[field_index]);
-            field_index += 1;
+            const align_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
+            extra_index += 1;
             // TODO: if we need to report an error here, use a source location
             // that points to this alignment expression rather than the struct.
             // But only resolve the source location if we need to emit a compile error.
             gop.entry.value.abi_align = (try sema.resolveInstConst(block, src, align_ref)).val;
         }
         if (has_default) {
-            const default_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[field_index]);
-            field_index += 1;
+            const default_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
+            extra_index += 1;
             // TODO: if we need to report an error here, use a source location
             // that points to this default value expression rather than the struct.
             // But only resolve the source location if we need to emit a compile error.
@@ -696,11 +696,164 @@ fn zirEnumDecl(
     const tracy = trace(@src());
     defer tracy.end();
 
+    const gpa = sema.gpa;
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const src = inst_data.src();
-    const extra = sema.code.extraData(zir.Inst.Block, inst_data.payload_index);
+    const extra = sema.code.extraData(zir.Inst.EnumDecl, inst_data.payload_index);
+    const body = sema.code.extra[extra.end..][0..extra.data.body_len];
+    const fields_len = extra.data.fields_len;
+
+    var new_decl_arena = std.heap.ArenaAllocator.init(gpa);
+
+    const tag_ty = blk: {
+        if (extra.data.tag_type != .none) {
+            // TODO better source location
+            // TODO (needs AstGen fix too) move this eval to the block so it gets allocated
+            // in the new decl arena.
+            break :blk try sema.resolveType(block, src, extra.data.tag_type);
+        }
+        const bits = std.math.log2_int_ceil(usize, fields_len);
+        break :blk try Type.Tag.int_unsigned.create(&new_decl_arena.allocator, bits);
+    };
+
+    const enum_obj = try new_decl_arena.allocator.create(Module.EnumFull);
+    const enum_ty_payload = try gpa.create(Type.Payload.EnumFull);
+    enum_ty_payload.* = .{
+        .base = .{ .tag = if (nonexhaustive) .enum_nonexhaustive else .enum_full },
+        .data = enum_obj,
+    };
+    const enum_ty = Type.initPayload(&enum_ty_payload.base);
+    const enum_val = try Value.Tag.ty.create(&new_decl_arena.allocator, enum_ty);
+    const new_decl = try sema.mod.createAnonymousDecl(&block.base, &new_decl_arena, .{
+        .ty = Type.initTag(.type),
+        .val = enum_val,
+    });
+    enum_obj.* = .{
+        .owner_decl = sema.owner_decl,
+        .tag_ty = tag_ty,
+        .fields = .{},
+        .values = .{},
+        .node_offset = inst_data.src_node,
+        .namespace = .{
+            .parent = sema.owner_decl.namespace,
+            .parent_name_hash = new_decl.fullyQualifiedNameHash(),
+            .ty = enum_ty,
+            .file_scope = block.getFileScope(),
+        },
+    };
+
+    {
+        const ast = std.zig.ast;
+        const node = sema.owner_decl.relativeToNodeIndex(inst_data.src_node);
+        const tree: *const ast.Tree = &enum_obj.namespace.file_scope.tree;
+        const node_tags = tree.nodes.items(.tag);
+        var buf: [2]ast.Node.Index = undefined;
+        const members: []const ast.Node.Index = switch (node_tags[node]) {
+            .container_decl,
+            .container_decl_trailing,
+            => tree.containerDecl(node).ast.members,
+
+            .container_decl_two,
+            .container_decl_two_trailing,
+            => tree.containerDeclTwo(&buf, node).ast.members,
+
+            .container_decl_arg,
+            .container_decl_arg_trailing,
+            => tree.containerDeclArg(node).ast.members,
+
+            .root => tree.rootDecls(),
+            else => unreachable,
+        };
+        try sema.mod.analyzeNamespace(&enum_obj.namespace, members);
+    }
+
+    if (fields_len == 0) {
+        assert(body.len == 0);
+        return sema.analyzeDeclVal(block, src, new_decl);
+    }
+
+    const bit_bags_count = std.math.divCeil(usize, fields_len, 32) catch unreachable;
+    const body_end = extra.end + body.len;
+
+    try enum_obj.fields.ensureCapacity(&new_decl_arena.allocator, fields_len);
+    const any_values = for (sema.code.extra[body_end..][0..bit_bags_count]) |bag| {
+        if (bag != 0) break true;
+    } else false;
+    if (any_values) {
+        try enum_obj.values.ensureCapacity(&new_decl_arena.allocator, fields_len);
+    }
+
+    {
+        // We create a block for the field type instructions because they
+        // may need to reference Decls from inside the enum namespace.
+        // Within the field type, default value, and alignment expressions, the "owner decl"
+        // should be the enum itself. Thus we need a new Sema.
+        var enum_sema: Sema = .{
+            .mod = sema.mod,
+            .gpa = gpa,
+            .arena = &new_decl_arena.allocator,
+            .code = sema.code,
+            .inst_map = sema.inst_map,
+            .owner_decl = new_decl,
+            .namespace = &enum_obj.namespace,
+            .owner_func = null,
+            .func = null,
+            .param_inst_list = &.{},
+            .branch_quota = sema.branch_quota,
+            .branch_count = sema.branch_count,
+        };
+
+        var enum_block: Scope.Block = .{
+            .parent = null,
+            .sema = &enum_sema,
+            .src_decl = new_decl,
+            .instructions = .{},
+            .inlining = null,
+            .is_comptime = true,
+        };
+        defer assert(enum_block.instructions.items.len == 0); // should all be comptime instructions
+
+        _ = try enum_sema.analyzeBody(&enum_block, body);
+
+        sema.branch_count = enum_sema.branch_count;
+        sema.branch_quota = enum_sema.branch_quota;
+    }
+    var extra_index: usize = body_end + bit_bags_count;
+    var bit_bag_index: usize = body_end;
+    var cur_bit_bag: u32 = undefined;
+    var field_i: u32 = 0;
+    while (field_i < fields_len) : (field_i += 1) {
+        if (field_i % 32 == 0) {
+            cur_bit_bag = sema.code.extra[bit_bag_index];
+            bit_bag_index += 1;
+        }
+        const has_tag_value = @truncate(u1, cur_bit_bag) != 0;
+        cur_bit_bag >>= 1;
+
+        const field_name_zir = sema.code.nullTerminatedString(sema.code.extra[extra_index]);
+        extra_index += 1;
 
-    return sema.mod.fail(&block.base, sema.src, "TODO implement zirEnumDecl", .{});
+        // This string needs to outlive the ZIR code.
+        const field_name = try new_decl_arena.allocator.dupe(u8, field_name_zir);
+
+        const gop = enum_obj.fields.getOrPutAssumeCapacity(field_name);
+        assert(!gop.found_existing);
+
+        if (has_tag_value) {
+            const tag_val_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
+            extra_index += 1;
+            // TODO: if we need to report an error here, use a source location
+            // that points to this default value expression rather than the struct.
+            // But only resolve the source location if we need to emit a compile error.
+            const tag_val = (try sema.resolveInstConst(block, src, tag_val_ref)).val;
+            enum_obj.values.putAssumeCapacityNoClobber(tag_val, {});
+        } else if (any_values) {
+            const tag_val = try Value.Tag.int_u64.create(&new_decl_arena.allocator, field_i);
+            enum_obj.values.putAssumeCapacityNoClobber(tag_val, {});
+        }
+    }
+
+    return sema.analyzeDeclVal(block, src, new_decl);
 }
 
 fn zirUnionDecl(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {
src/zir.zig
@@ -1532,13 +1532,17 @@ pub const Inst = struct {
     };
 
     /// Trailing:
-    /// 0. has_bits: u32 // for every 32 fields
+    /// 0. inst: Index // for every body_len
+    /// 1. has_bits: u32 // for every 32 fields
     ///    - the bit is whether corresponding field has an value expression
-    /// 1. field_name: u32 // for every field: null terminated string index
-    /// 2. value: Ref // for every field for which corresponding bit is set
+    /// 2. fields: { // for every fields_len
+    ///        field_name: u32,
+    ///        value: Ref, // if corresponding bit is set
+    ///    }
     pub const EnumDecl = struct {
         /// Can be `Ref.none`.
         tag_type: Ref,
+        body_len: u32,
         fields_len: u32,
     };
 
@@ -1704,8 +1708,6 @@ const Writer = struct {
             .slice_end,
             .slice_sentinel,
             .union_decl,
-            .enum_decl,
-            .enum_decl_nonexhaustive,
             .struct_init,
             .field_type,
             => try self.writePlNode(stream, inst),
@@ -1761,6 +1763,10 @@ const Writer = struct {
             .struct_decl_extern,
             => try self.writeStructDecl(stream, inst),
 
+            .enum_decl,
+            .enum_decl_nonexhaustive,
+            => try self.writeEnumDecl(stream, inst),
+
             .switch_block => try self.writePlNodeSwitchBr(stream, inst, .none),
             .switch_block_else => try self.writePlNodeSwitchBr(stream, inst, .@"else"),
             .switch_block_under => try self.writePlNodeSwitchBr(stream, inst, .under),
@@ -2031,7 +2037,7 @@ const Writer = struct {
 
         const bit_bags_count = std.math.divCeil(usize, fields_len, 16) catch unreachable;
         const body_end = extra.end + body.len;
-        var field_index: usize = body_end + bit_bags_count;
+        var extra_index: usize = body_end + bit_bags_count;
         var bit_bag_index: usize = body_end;
         var cur_bit_bag: u32 = undefined;
         var field_i: u32 = 0;
@@ -2045,26 +2051,26 @@ const Writer = struct {
             const has_default = @truncate(u1, cur_bit_bag) != 0;
             cur_bit_bag >>= 1;
 
-            const field_name = self.code.nullTerminatedString(self.code.extra[field_index]);
-            field_index += 1;
-            const field_type = @intToEnum(Inst.Ref, self.code.extra[field_index]);
-            field_index += 1;
+            const field_name = self.code.nullTerminatedString(self.code.extra[extra_index]);
+            extra_index += 1;
+            const field_type = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+            extra_index += 1;
 
             try stream.writeByteNTimes(' ', self.indent);
             try stream.print("{}: ", .{std.zig.fmtId(field_name)});
             try self.writeInstRef(stream, field_type);
 
             if (has_align) {
-                const align_ref = @intToEnum(Inst.Ref, self.code.extra[field_index]);
-                field_index += 1;
+                const align_ref = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
 
                 try stream.writeAll(" align(");
                 try self.writeInstRef(stream, align_ref);
                 try stream.writeAll(")");
             }
             if (has_default) {
-                const default_ref = @intToEnum(Inst.Ref, self.code.extra[field_index]);
-                field_index += 1;
+                const default_ref = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
 
                 try stream.writeAll(" = ");
                 try self.writeInstRef(stream, default_ref);
@@ -2078,6 +2084,68 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    fn writeEnumDecl(self: *Writer, stream: anytype, inst: Inst.Index) !void {
+        const inst_data = self.code.instructions.items(.data)[inst].pl_node;
+        const extra = self.code.extraData(Inst.EnumDecl, inst_data.payload_index);
+        const body = self.code.extra[extra.end..][0..extra.data.body_len];
+        const fields_len = extra.data.fields_len;
+        const tag_ty_ref = extra.data.tag_type;
+
+        if (tag_ty_ref != .none) {
+            try self.writeInstRef(stream, tag_ty_ref);
+            try stream.writeAll(", ");
+        }
+
+        if (fields_len == 0) {
+            assert(body.len == 0);
+            try stream.writeAll("{}, {}) ");
+            try self.writeSrc(stream, inst_data.src());
+            return;
+        }
+
+        try stream.writeAll("{\n");
+        self.indent += 2;
+        try self.writeBody(stream, body);
+
+        try stream.writeByteNTimes(' ', self.indent - 2);
+        try stream.writeAll("}, {\n");
+
+        const bit_bags_count = std.math.divCeil(usize, fields_len, 32) catch unreachable;
+        const body_end = extra.end + body.len;
+        var extra_index: usize = body_end + bit_bags_count;
+        var bit_bag_index: usize = body_end;
+        var cur_bit_bag: u32 = undefined;
+        var field_i: u32 = 0;
+        while (field_i < fields_len) : (field_i += 1) {
+            if (field_i % 32 == 0) {
+                cur_bit_bag = self.code.extra[bit_bag_index];
+                bit_bag_index += 1;
+            }
+            const has_tag_value = @truncate(u1, cur_bit_bag) != 0;
+            cur_bit_bag >>= 1;
+
+            const field_name = self.code.nullTerminatedString(self.code.extra[extra_index]);
+            extra_index += 1;
+
+            try stream.writeByteNTimes(' ', self.indent);
+            try stream.print("{}", .{std.zig.fmtId(field_name)});
+
+            if (has_tag_value) {
+                const tag_value_ref = @intToEnum(Inst.Ref, self.code.extra[extra_index]);
+                extra_index += 1;
+
+                try stream.writeAll(" = ");
+                try self.writeInstRef(stream, tag_value_ref);
+            }
+            try stream.writeAll(",\n");
+        }
+
+        self.indent -= 2;
+        try stream.writeByteNTimes(' ', self.indent);
+        try stream.writeAll("}) ");
+        try self.writeSrc(stream, inst_data.src());
+    }
+
     fn writePlNodeSwitchBr(
         self: *Writer,
         stream: anytype,