Commit 55e86b724a

Andrew Kelley <andrew@ziglang.org>
2021-04-30 01:57:13
AstGen: implement comptime struct fields
1 parent fb4cb43
src/AstGen.zig
@@ -3220,7 +3220,9 @@ fn structDeclInner(
     var fields_data = ArrayListUnmanaged(u32){};
     defer fields_data.deinit(gpa);
 
-    // We only need this if there are greater than 16 fields.
+    const bits_per_field = 4;
+    const fields_per_u32 = 32 / bits_per_field;
+    // We only need this if there are greater than fields_per_u32 fields.
     var bit_bag = ArrayListUnmanaged(u32){};
     defer bit_bag.deinit(gpa);
 
@@ -3307,13 +3309,10 @@ fn structDeclInner(
             },
             else => unreachable,
         };
-        if (field_index % 16 == 0 and field_index != 0) {
+        if (field_index % fields_per_u32 == 0 and field_index != 0) {
             try bit_bag.append(gpa, cur_bit_bag);
             cur_bit_bag = 0;
         }
-        if (member.comptime_token) |comptime_token| {
-            return astgen.failTok(comptime_token, "TODO implement comptime struct fields", .{});
-        }
         try fields_data.ensureUnusedCapacity(gpa, 4);
 
         const field_name = try gz.identAsString(member.ast.name_token);
@@ -3324,9 +3323,13 @@ fn structDeclInner(
 
         const have_align = member.ast.align_expr != 0;
         const have_value = member.ast.value_expr != 0;
-        cur_bit_bag = (cur_bit_bag >> 2) |
-            (@as(u32, @boolToInt(have_align)) << 30) |
-            (@as(u32, @boolToInt(have_value)) << 31);
+        const is_comptime = member.comptime_token != null;
+        const unused = false;
+        cur_bit_bag = (cur_bit_bag >> bits_per_field) |
+            (@as(u32, @boolToInt(have_align)) << 28) |
+            (@as(u32, @boolToInt(have_value)) << 29) |
+            (@as(u32, @boolToInt(is_comptime)) << 30) |
+            (@as(u32, @boolToInt(unused)) << 31);
 
         if (have_align) {
             const align_inst = try expr(&block_scope, &block_scope.base, align_rl, member.ast.align_expr);
@@ -3335,14 +3338,16 @@ fn structDeclInner(
         if (have_value) {
             const default_inst = try expr(&block_scope, &block_scope.base, .{ .ty = field_type }, member.ast.value_expr);
             fields_data.appendAssumeCapacity(@enumToInt(default_inst));
+        } else if (member.comptime_token) |comptime_token| {
+            return astgen.failTok(comptime_token, "comptime field without default initialization value", .{});
         }
 
         field_index += 1;
     }
     {
-        const empty_slot_count = 16 - (field_index % 16);
-        if (empty_slot_count < 16) {
-            cur_bit_bag >>= @intCast(u5, empty_slot_count * 2);
+        const empty_slot_count = fields_per_u32 - (field_index % fields_per_u32);
+        if (empty_slot_count < fields_per_u32) {
+            cur_bit_bag >>= @intCast(u5, empty_slot_count * bits_per_field);
         }
     }
     {
src/Module.zig
@@ -471,6 +471,7 @@ pub const Struct = struct {
         abi_align: Value,
         /// Uses `unreachable_value` to indicate no default.
         default_val: Value,
+        is_comptime: bool,
     };
 
     pub fn getFullyQualifiedName(s: *Struct, gpa: *Allocator) ![]u8 {
src/Sema.zig
@@ -719,14 +719,16 @@ pub fn zirStructDecl(
         sema.branch_count = struct_sema.branch_count;
         sema.branch_quota = struct_sema.branch_quota;
     }
-    const bit_bags_count = std.math.divCeil(usize, fields_len, 16) catch unreachable;
+    const bits_per_field = 4;
+    const fields_per_u32 = 32 / bits_per_field;
+    const bit_bags_count = std.math.divCeil(usize, fields_len, fields_per_u32) catch unreachable;
     const body_end = extra_index + body.len;
     extra_index += bit_bags_count;
     var bit_bag_index: usize = body_end;
     var cur_bit_bag: u32 = undefined;
     var field_i: u32 = 0;
     while (field_i < fields_len) : (field_i += 1) {
-        if (field_i % 16 == 0) {
+        if (field_i % fields_per_u32 == 0) {
             cur_bit_bag = sema.code.extra[bit_bag_index];
             bit_bag_index += 1;
         }
@@ -734,6 +736,12 @@ pub fn zirStructDecl(
         cur_bit_bag >>= 1;
         const has_default = @truncate(u1, cur_bit_bag) != 0;
         cur_bit_bag >>= 1;
+        const is_comptime = @truncate(u1, cur_bit_bag) != 0;
+        cur_bit_bag >>= 1;
+        const unused = @truncate(u1, cur_bit_bag) != 0;
+        cur_bit_bag >>= 1;
+
+        _ = unused;
 
         const field_name_zir = sema.code.nullTerminatedString(sema.code.extra[extra_index]);
         extra_index += 1;
@@ -753,6 +761,7 @@ pub fn zirStructDecl(
             .ty = field_ty,
             .abi_align = Value.initTag(.abi_align_default),
             .default_val = Value.initTag(.unreachable_value),
+            .is_comptime = is_comptime,
         };
 
         if (has_align) {
src/Zir.zig
@@ -2387,13 +2387,16 @@ pub const Inst = struct {
     ///        link_section: Ref, // if corresponding bit is set
     ///    }
     /// 2. inst: Index // for every body_len
-    /// 3. has_bits: u32 // for every 16 fields
-    ///    - sets of 2 bits:
-    ///      0b0X: whether corresponding field has an align expression
-    ///      0bX0: whether corresponding field has a default expression
+    /// 3. flags: u32 // for every 8 fields
+    ///    - sets of 4 bits:
+    ///      0b000X: whether corresponding field has an align expression
+    ///      0b00X0: whether corresponding field has a default expression
+    ///      0b0X00: whether corresponding field is comptime
+    ///      0bX000: unused
     /// 4. fields: { // for every fields_len
     ///        field_name: u32,
     ///        field_type: Ref,
+    ///        - if none, means `anytype`.
     ///        align: Ref, // if corresponding bit is set
     ///        default_value: Ref, // if corresponding bit is set
     ///    }
@@ -3394,14 +3397,16 @@ const Writer = struct {
                 try stream.writeAll("}, {\n");
             }
 
-            const bit_bags_count = std.math.divCeil(usize, fields_len, 16) catch unreachable;
+            const bits_per_field = 4;
+            const fields_per_u32 = 32 / bits_per_field;
+            const bit_bags_count = std.math.divCeil(usize, fields_len, fields_per_u32) catch unreachable;
             const body_end = extra_index;
             extra_index += bit_bags_count;
             var bit_bag_index: usize = body_end;
             var cur_bit_bag: u32 = undefined;
             var field_i: u32 = 0;
             while (field_i < fields_len) : (field_i += 1) {
-                if (field_i % 16 == 0) {
+                if (field_i % fields_per_u32 == 0) {
                     cur_bit_bag = self.code.extra[bit_bag_index];
                     bit_bag_index += 1;
                 }
@@ -3409,6 +3414,12 @@ const Writer = struct {
                 cur_bit_bag >>= 1;
                 const has_default = @truncate(u1, cur_bit_bag) != 0;
                 cur_bit_bag >>= 1;
+                const is_comptime = @truncate(u1, cur_bit_bag) != 0;
+                cur_bit_bag >>= 1;
+                const unused = @truncate(u1, cur_bit_bag) != 0;
+                cur_bit_bag >>= 1;
+
+                _ = unused;
 
                 const field_name = self.code.nullTerminatedString(self.code.extra[extra_index]);
                 extra_index += 1;
@@ -3416,6 +3427,7 @@ const Writer = struct {
                 extra_index += 1;
 
                 try stream.writeByteNTimes(' ', self.indent);
+                try self.writeFlag(stream, "comptime ", is_comptime);
                 try stream.print("{}: ", .{std.zig.fmtId(field_name)});
                 try self.writeInstRef(stream, field_type);
 
BRANCH_TODO
@@ -52,6 +52,10 @@
    - not sure why this happened, it's stage1 code??
    - search the behavior test diff for "TODO"
 
+ * memory efficiency: add another representation for structs which use
+   natural alignment for fields and do not have any comptime fields. this
+   will save 16 bytes per struct field in the compilation.
+
 fn getAnonTypeName(mod: *Module, scope: *Scope, base_token: std.zig.ast.TokenIndex) ![]u8 {
     // TODO add namespaces, generic function signatrues
     const tree = scope.tree();