Commit 2c639d6570

GethDW <gethwilliams@googlemail.com>
2023-03-23 11:00:10
std.MultiArrayList: add support for tagged unions.
1 parent 88dfb13
Changed files (3)
lib
src
translate_c
lib/std/zig/Parse.zig
@@ -41,13 +41,13 @@ fn listToSpan(p: *Parse, list: []const Node.Index) !Node.SubRange {
     };
 }
 
-fn addNode(p: *Parse, elem: Ast.NodeList.Elem) Allocator.Error!Node.Index {
+fn addNode(p: *Parse, elem: Ast.Node) Allocator.Error!Node.Index {
     const result = @intCast(Node.Index, p.nodes.len);
     try p.nodes.append(p.gpa, elem);
     return result;
 }
 
-fn setNode(p: *Parse, i: usize, elem: Ast.NodeList.Elem) Node.Index {
+fn setNode(p: *Parse, i: usize, elem: Ast.Node) Node.Index {
     p.nodes.set(i, elem);
     return @intCast(Node.Index, i);
 }
lib/std/multi_array_list.zig
@@ -1,4 +1,4 @@
-const std = @import("std.zig");
+const std = @import("std");
 const builtin = @import("builtin");
 const assert = std.debug.assert;
 const meta = std.meta;
@@ -6,24 +6,57 @@ const mem = std.mem;
 const Allocator = mem.Allocator;
 const testing = std.testing;
 
-/// A MultiArrayList stores a list of a struct type.
+/// A MultiArrayList stores a list of a struct or tagged union type.
 /// Instead of storing a single list of items, MultiArrayList
-/// stores separate lists for each field of the struct.
-/// This allows for memory savings if the struct has padding,
-/// and also improves cache usage if only some fields are needed
-/// for a computation.  The primary API for accessing fields is
+/// stores separate lists for each field of the struct or
+/// lists of tags and bare unions.
+/// This allows for memory savings if the struct or union has padding,
+/// and also improves cache usage if only some fields or or just tags
+/// are needed for a computation.  The primary API for accessing fields is
 /// the `slice()` function, which computes the start pointers
 /// for the array of each field.  From the slice you can call
 /// `.items(.<field_name>)` to obtain a slice of field values.
-pub fn MultiArrayList(comptime S: type) type {
+/// For unions you can call `.items(.tags)` or `.items(.data)`.
+pub fn MultiArrayList(comptime T: type) type {
     return struct {
-        bytes: [*]align(@alignOf(S)) u8 = undefined,
+        bytes: [*]align(@alignOf(T)) u8 = undefined,
         len: usize = 0,
         capacity: usize = 0,
 
-        pub const Elem = S;
+        const Elem = switch (@typeInfo(T)) {
+            .Struct => T,
+            .Union => |u| struct {
+                pub const Bare =
+                    @Type(.{ .Union = .{
+                    .layout = u.layout,
+                    .tag_type = null,
+                    .fields = u.fields,
+                    .decls = &.{},
+                } });
+                pub const Tag =
+                    u.tag_type orelse @compileError("MultiArrayList does not support untagged unions");
+                tags: Tag,
+                data: Bare,
+
+                pub fn fromT(outer: T) @This() {
+                    const tag = meta.activeTag(outer);
+                    return .{
+                        .tags = tag,
+                        .data = switch (tag) {
+                            inline else => |t| @unionInit(Bare, @tagName(t), @field(outer, @tagName(t))),
+                        },
+                    };
+                }
+                pub fn toT(tag: Tag, bare: Bare) T {
+                    return switch (tag) {
+                        inline else => |t| @unionInit(T, @tagName(t), @field(bare, @tagName(t))),
+                    };
+                }
+            },
+            else => @compileError("MultiArrayList only supports structs and tagged unions"),
+        };
 
-        pub const Field = meta.FieldEnum(S);
+        pub const Field = meta.FieldEnum(Elem);
 
         /// A MultiArrayList.Slice contains cached start pointers for each field in the list.
         /// These pointers are not normally stored to reduce the size of the list in memory.
@@ -49,18 +82,27 @@ pub fn MultiArrayList(comptime S: type) type {
                 return casted_ptr[0..self.len];
             }
 
-            pub fn set(self: Slice, index: usize, elem: S) void {
-                inline for (fields) |field_info| {
-                    self.items(@field(Field, field_info.name))[index] = @field(elem, field_info.name);
+            pub fn set(self: *Slice, index: usize, elem: T) void {
+                const e = switch (@typeInfo(T)) {
+                    .Struct => elem,
+                    .Union => Elem.fromT(elem),
+                    else => unreachable,
+                };
+                inline for (fields, 0..) |field_info, i| {
+                    self.items(@intToEnum(Field, i))[index] = @field(e, field_info.name);
                 }
             }
 
-            pub fn get(self: Slice, index: usize) S {
-                var elem: S = undefined;
-                inline for (fields) |field_info| {
-                    @field(elem, field_info.name) = self.items(@field(Field, field_info.name))[index];
+            pub fn get(self: Slice, index: usize) T {
+                var result: Elem = undefined;
+                inline for (fields, 0..) |field_info, i| {
+                    @field(result, field_info.name) = self.items(@intToEnum(Field, i))[index];
                 }
-                return elem;
+                return switch (@typeInfo(T)) {
+                    .Struct => result,
+                    .Union => Elem.toT(result.tags, result.data),
+                    else => unreachable,
+                };
             }
 
             pub fn toMultiArrayList(self: Slice) Self {
@@ -68,8 +110,8 @@ pub fn MultiArrayList(comptime S: type) type {
                     return .{};
                 }
                 const unaligned_ptr = self.ptrs[sizes.fields[0]];
-                const aligned_ptr = @alignCast(@alignOf(S), unaligned_ptr);
-                const casted_ptr = @ptrCast([*]align(@alignOf(S)) u8, aligned_ptr);
+                const aligned_ptr = @alignCast(@alignOf(Elem), unaligned_ptr);
+                const casted_ptr = @ptrCast([*]align(@alignOf(Elem)) u8, aligned_ptr);
                 return .{
                     .bytes = casted_ptr,
                     .len = self.len,
@@ -85,7 +127,7 @@ pub fn MultiArrayList(comptime S: type) type {
 
             /// This function is used in the debugger pretty formatters in tools/ to fetch the
             /// child field order and entry type to facilitate fancy debug printing for this type.
-            fn dbHelper(self: *Slice, child: *S, field: *Field, entry: *Entry) void {
+            fn dbHelper(self: *Slice, child: *Elem, field: *Field, entry: *Entry) void {
                 _ = self;
                 _ = child;
                 _ = field;
@@ -95,8 +137,8 @@ pub fn MultiArrayList(comptime S: type) type {
 
         const Self = @This();
 
-        const fields = meta.fields(S);
-        /// `sizes.bytes` is an array of @sizeOf each S field. Sorted by alignment, descending.
+        const fields = meta.fields(Elem);
+        /// `sizes.bytes` is an array of @sizeOf each T field. Sorted by alignment, descending.
         /// `sizes.fields` is an array mapping from `sizes.bytes` array index to field index.
         const sizes = blk: {
             const Data = struct {
@@ -169,24 +211,25 @@ pub fn MultiArrayList(comptime S: type) type {
         }
 
         /// Overwrite one array element with new data.
-        pub fn set(self: *Self, index: usize, elem: S) void {
-            return self.slice().set(index, elem);
+        pub fn set(self: *Self, index: usize, elem: T) void {
+            var slices = self.slice();
+            slices.set(index, elem);
         }
 
         /// Obtain all the data for one array element.
-        pub fn get(self: Self, index: usize) S {
+        pub fn get(self: Self, index: usize) T {
             return self.slice().get(index);
         }
 
         /// Extend the list by 1 element. Allocates more memory as necessary.
-        pub fn append(self: *Self, gpa: Allocator, elem: S) !void {
+        pub fn append(self: *Self, gpa: Allocator, elem: T) !void {
             try self.ensureUnusedCapacity(gpa, 1);
             self.appendAssumeCapacity(elem);
         }
 
         /// Extend the list by 1 element, but asserting `self.capacity`
         /// is sufficient to hold an additional item.
-        pub fn appendAssumeCapacity(self: *Self, elem: S) void {
+        pub fn appendAssumeCapacity(self: *Self, elem: T) void {
             assert(self.len < self.capacity);
             self.len += 1;
             self.set(self.len - 1, elem);
@@ -213,7 +256,7 @@ pub fn MultiArrayList(comptime S: type) type {
         /// Remove and return the last element from the list.
         /// Asserts the list has at least one item.
         /// Invalidates pointers to fields of the removed element.
-        pub fn pop(self: *Self) S {
+        pub fn pop(self: *Self) T {
             const val = self.get(self.len - 1);
             self.len -= 1;
             return val;
@@ -222,7 +265,7 @@ pub fn MultiArrayList(comptime S: type) type {
         /// Remove and return the last element from the list, or
         /// return `null` if list is empty.
         /// Invalidates pointers to fields of the removed element, if any.
-        pub fn popOrNull(self: *Self) ?S {
+        pub fn popOrNull(self: *Self) ?T {
             if (self.len == 0) return null;
             return self.pop();
         }
@@ -231,7 +274,7 @@ pub fn MultiArrayList(comptime S: type) type {
         /// after and including the specified index back by one and
         /// sets the given index to the specified element.  May reallocate
         /// and invalidate iterators.
-        pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: S) !void {
+        pub fn insert(self: *Self, gpa: Allocator, index: usize, elem: T) !void {
             try self.ensureUnusedCapacity(gpa, 1);
             self.insertAssumeCapacity(index, elem);
         }
@@ -240,10 +283,15 @@ pub fn MultiArrayList(comptime S: type) type {
         /// Shifts all elements after and including the specified index
         /// back by one and sets the given index to the specified element.
         /// Will not reallocate the array, does not invalidate iterators.
-        pub fn insertAssumeCapacity(self: *Self, index: usize, elem: S) void {
+        pub fn insertAssumeCapacity(self: *Self, index: usize, elem: T) void {
             assert(self.len < self.capacity);
             assert(index <= self.len);
             self.len += 1;
+            const entry = switch (@typeInfo(T)) {
+                .Struct => elem,
+                .Union => Elem.fromT(elem),
+                else => unreachable,
+            };
             const slices = self.slice();
             inline for (fields, 0..) |field_info, field_index| {
                 const field_slice = slices.items(@intToEnum(Field, field_index));
@@ -251,7 +299,7 @@ pub fn MultiArrayList(comptime S: type) type {
                 while (i > index) : (i -= 1) {
                     field_slice[i] = field_slice[i - 1];
                 }
-                field_slice[index] = @field(elem, field_info.name);
+                field_slice[index] = @field(entry, field_info.name);
             }
         }
 
@@ -304,7 +352,7 @@ pub fn MultiArrayList(comptime S: type) type {
 
             const other_bytes = gpa.alignedAlloc(
                 u8,
-                @alignOf(S),
+                @alignOf(Elem),
                 capacityInBytes(new_len),
             ) catch {
                 const self_slice = self.slice();
@@ -375,7 +423,7 @@ pub fn MultiArrayList(comptime S: type) type {
             assert(new_capacity >= self.len);
             const new_bytes = try gpa.alignedAlloc(
                 u8,
-                @alignOf(S),
+                @alignOf(Elem),
                 capacityInBytes(new_capacity),
             );
             if (self.len == 0) {
@@ -453,12 +501,12 @@ pub fn MultiArrayList(comptime S: type) type {
             return elem_bytes * capacity;
         }
 
-        fn allocatedBytes(self: Self) []align(@alignOf(S)) u8 {
+        fn allocatedBytes(self: Self) []align(@alignOf(Elem)) u8 {
             return self.bytes[0..capacityInBytes(self.capacity)];
         }
 
         fn FieldType(comptime field: Field) type {
-            return meta.fieldInfo(S, field).type;
+            return meta.fieldInfo(Elem, field).type;
         }
 
         const Entry = entry: {
@@ -479,7 +527,7 @@ pub fn MultiArrayList(comptime S: type) type {
         };
         /// This function is used in the debugger pretty formatters in tools/ to fetch the
         /// child field order and entry type to facilitate fancy debug printing for this type.
-        fn dbHelper(self: *Self, child: *S, field: *Field, entry: *Entry) void {
+        fn dbHelper(self: *Self, child: *Elem, field: *Field, entry: *Entry) void {
             _ = self;
             _ = child;
             _ = field;
@@ -719,3 +767,58 @@ test "insert elements" {
     try testing.expectEqualSlices(u8, &[_]u8{ 1, 2 }, list.items(.a));
     try testing.expectEqualSlices(u32, &[_]u32{ 2, 3 }, list.items(.b));
 }
+
+test "union" {
+    const ally = testing.allocator;
+
+    const Foo = union(enum) {
+        a: u32,
+        b: []const u8,
+    };
+
+    var list = MultiArrayList(Foo){};
+    defer list.deinit(ally);
+
+    try testing.expectEqual(@as(usize, 0), list.items(.tags).len);
+
+    try list.ensureTotalCapacity(ally, 2);
+
+    list.appendAssumeCapacity(.{ .a = 1 });
+    list.appendAssumeCapacity(.{ .b = "zigzag" });
+
+    try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b });
+    try testing.expectEqual(@as(usize, 2), list.items(.tags).len);
+
+    list.appendAssumeCapacity(.{ .b = "foobar" });
+    try testing.expectEqualStrings("zigzag", list.items(.data)[1].b);
+    try testing.expectEqualStrings("foobar", list.items(.data)[2].b);
+
+    // Add 6 more things to force a capacity increase.
+    for (0..6) |i| {
+        try list.append(ally, .{ .a = @intCast(u32, 4 + i) });
+    }
+
+    try testing.expectEqualSlices(
+        meta.Tag(Foo),
+        &.{ .a, .b, .b, .a, .a, .a, .a, .a, .a },
+        list.items(.tags),
+    );
+    try testing.expectEqual(list.get(0), .{ .a = 1 });
+    try testing.expectEqual(list.get(1), .{ .b = "zigzag" });
+    try testing.expectEqual(list.get(2), .{ .b = "foobar" });
+    try testing.expectEqual(list.get(3), .{ .a = 4 });
+    try testing.expectEqual(list.get(4), .{ .a = 5 });
+    try testing.expectEqual(list.get(5), .{ .a = 6 });
+    try testing.expectEqual(list.get(6), .{ .a = 7 });
+    try testing.expectEqual(list.get(7), .{ .a = 8 });
+    try testing.expectEqual(list.get(8), .{ .a = 9 });
+
+    list.shrinkAndFree(ally, 3);
+
+    try testing.expectEqual(@as(usize, 3), list.items(.tags).len);
+    try testing.expectEqualSlices(meta.Tag(Foo), list.items(.tags), &.{ .a, .b, .b });
+
+    try testing.expectEqual(list.get(0), .{ .a = 1 });
+    try testing.expectEqual(list.get(1), .{ .b = "zigzag" });
+    try testing.expectEqual(list.get(2), .{ .b = "foobar" });
+}
src/translate_c/ast.zig
@@ -845,7 +845,7 @@ const Context = struct {
         };
     }
 
-    fn addNode(c: *Context, elem: std.zig.Ast.NodeList.Elem) Allocator.Error!NodeIndex {
+    fn addNode(c: *Context, elem: std.zig.Ast.Node) Allocator.Error!NodeIndex {
         const result = @intCast(NodeIndex, c.nodes.len);
         try c.nodes.append(c.gpa, elem);
         return result;