Commit cd9af0f286

Pyrolistical <pyrogx1133@gmail.com>
2022-12-11 18:10:54
std: add EnumMultiSet
1 parent 05890a1
Changed files (1)
lib
lib/std/enums.zig
@@ -304,6 +304,346 @@ pub fn EnumMap(comptime E: type, comptime V: type) type {
     return IndexedMap(EnumIndexer(E), V, mixin.EnumMapExt);
 }
 
+/// A multiset of enum elements up to a count of usize. Backed
+/// by an EnumArray. This type does no dynamic allocation and can
+/// be copied by value.
+pub fn EnumMultiset(comptime E: type) type {
+    return BoundedEnumMultiset(E, usize);
+}
+
+/// A multiset of enum elements up to CountSize. Backed by an
+/// EnumArray. This type does no dynamic allocation and can be
+/// copied by value.
+pub fn BoundedEnumMultiset(comptime E: type, comptime CountSize: type) type {
+    return struct {
+        const Self = @This();
+
+        counts: EnumArray(E, CountSize),
+
+        /// Initializes the multiset using a struct of counts.
+        pub fn init(init_counts: EnumFieldStruct(E, CountSize, 0)) Self {
+            var self = initWithCount(0);
+            inline for (@typeInfo(E).Enum.fields) |field| {
+                const c = @field(init_counts, field.name);
+                const key = @intToEnum(E, field.value);
+                self.counts.set(key, c);
+            }
+            return self;
+        }
+
+        /// Initializes the multiset with a count of zero.
+        pub fn initEmpty() Self {
+            return initWithCount(0);
+        }
+
+        /// Initializes the multiset with all keys at the
+        /// same count.
+        pub fn initWithCount(comptime c: CountSize) Self {
+            return .{
+                .counts = EnumArray(E, CountSize).initDefault(c, .{}),
+            };
+        }
+
+        /// Returns the total number of key counts in the multiset.
+        pub fn count(self: Self) usize {
+            var sum: usize = 0;
+            for (self.counts.values) |c| {
+                sum += c;
+            }
+            return sum;
+        }
+
+        /// Checks if at least one key in multiset.
+        pub fn contains(self: Self, key: E) bool {
+            return self.counts.get(key) > 0;
+        }
+
+        /// Removes all instance of a key from multiset. Same as
+        /// setCount(key, 0).
+        pub fn removeAll(self: *Self, key: E) void {
+            return self.counts.set(key, 0);
+        }
+
+        /// Increases the key count by given amount. Caller asserts
+        /// operation will not overflow.
+        pub fn addAssertSafe(self: *Self, key: E, c: CountSize) void {
+            self.counts.getPtr(key).* += c;
+        }
+
+        /// Increases the key count by given amount.
+        pub fn add(self: *Self, key: E, c: CountSize) error{Overflow}!void {
+            self.counts.set(key, try std.math.add(CountSize, self.counts.get(key), c));
+        }
+
+        /// Decreases the key count by given amount. If amount is
+        /// greater than the number of keys in multset, then key count
+        /// will be set to zero.
+        pub fn remove(self: *Self, key: E, c: CountSize) void {
+            self.counts.getPtr(key).* -= @min(self.getCount(key), c);
+        }
+
+        /// Returns the count for a key.
+        pub fn getCount(self: Self, key: E) CountSize {
+            return self.counts.get(key);
+        }
+
+        /// Set the count for a key.
+        pub fn setCount(self: *Self, key: E, c: CountSize) void {
+            self.counts.set(key, c);
+        }
+
+        /// Increases the all key counts by given multiset. Caller
+        /// asserts operation will not overflow any key.
+        pub fn addSetAssertSafe(self: *Self, other: Self) void {
+            inline for (@typeInfo(E).Enum.fields) |field| {
+                const key = @intToEnum(E, field.value);
+                self.addAssertSafe(key, other.getCount(key));
+            }
+        }
+
+        /// Increases the all key counts by given multiset.
+        pub fn addSet(self: *Self, other: Self) error{Overflow}!void {
+            inline for (@typeInfo(E).Enum.fields) |field| {
+                const key = @intToEnum(E, field.value);
+                try self.add(key, other.getCount(key));
+            }
+        }
+
+        /// Deccreases the all key counts by given multiset. If
+        /// the given multiset has more key counts than this,
+        /// then that key will have a key count of zero.
+        pub fn removeSet(self: *Self, other: Self) void {
+            inline for (@typeInfo(E).Enum.fields) |field| {
+                const key = @intToEnum(E, field.value);
+                self.remove(key, other.getCount(key));
+            }
+        }
+
+        /// Returns true iff all key counts are the same as
+        /// given multiset.
+        pub fn eql(self: Self, other: Self) bool {
+            inline for (@typeInfo(E).Enum.fields) |field| {
+                const key = @intToEnum(E, field.value);
+                if (self.getCount(key) != other.getCount(key)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        /// Returns a multiset with the total key count of this
+        /// multiset and the other multiset. Caller asserts
+        /// operation will not overflow any key.
+        pub fn plusAssertSafe(self: Self, other: Self) Self {
+            var result = self;
+            result.addSetAssertSafe(other);
+            return result;
+        }
+
+        /// Returns a multiset with the total key count of this
+        /// multiset and the other multiset.
+        pub fn plus(self: Self, other: Self) error{Overflow}!Self {
+            var result = self;
+            try result.addSet(other);
+            return result;
+        }
+
+        /// Returns a multiset with the key count of this
+        /// multiset minus the corresponding key count in the
+        /// other multiset. If the other multiset contains
+        /// more key count than this set, that key will have
+        /// a count of zero.
+        pub fn minus(self: Self, other: Self) Self {
+            var result = self;
+            result.removeSet(other);
+            return result;
+        }
+
+        pub const Entry = EnumArray(E, CountSize).Entry;
+        pub const Iterator = EnumArray(E, CountSize).Iterator;
+
+        /// Returns an iterator over this multiset. Keys with zero
+        /// counts are included. Modifications to the set during
+        /// iteration may or may not be observed by the iterator,
+        /// but will not invalidate it.
+        pub fn iterator(self: *Self) Iterator {
+            return self.counts.iterator();
+        }
+    };
+}
+
+test "EnumMultiset" {
+    const Ball = enum { red, green, blue };
+
+    const empty = EnumMultiset(Ball).initEmpty();
+    const r0_g1_b2 = EnumMultiset(Ball).init(.{
+        .red = 0,
+        .green = 1,
+        .blue = 2,
+    });
+    const ten_of_each = EnumMultiset(Ball).initWithCount(10);
+
+    try testing.expectEqual(empty.count(), 0);
+    try testing.expectEqual(r0_g1_b2.count(), 3);
+    try testing.expectEqual(ten_of_each.count(), 30);
+
+    try testing.expect(!empty.contains(.red));
+    try testing.expect(!empty.contains(.green));
+    try testing.expect(!empty.contains(.blue));
+
+    try testing.expect(!r0_g1_b2.contains(.red));
+    try testing.expect(r0_g1_b2.contains(.green));
+    try testing.expect(r0_g1_b2.contains(.blue));
+
+    try testing.expect(ten_of_each.contains(.red));
+    try testing.expect(ten_of_each.contains(.green));
+    try testing.expect(ten_of_each.contains(.blue));
+
+    {
+        var copy = ten_of_each;
+        copy.removeAll(.red);
+        try testing.expect(!copy.contains(.red));
+
+        // removeAll second time does nothing
+        copy.removeAll(.red);
+        try testing.expect(!copy.contains(.red));
+    }
+
+    {
+        var copy = ten_of_each;
+        copy.addAssertSafe(.red, 6);
+        try testing.expectEqual(copy.getCount(.red), 16);
+    }
+
+    {
+        var copy = ten_of_each;
+        try copy.add(.red, 6);
+        try testing.expectEqual(copy.getCount(.red), 16);
+
+        try testing.expectError(error.Overflow, copy.add(.red, std.math.maxInt(usize)));
+    }
+
+    {
+        var copy = ten_of_each;
+        copy.remove(.red, 4);
+        try testing.expectEqual(copy.getCount(.red), 6);
+
+        // subtracting more it contains does not underflow
+        copy.remove(.green, 14);
+        try testing.expectEqual(copy.getCount(.green), 0);
+    }
+
+    try testing.expectEqual(empty.getCount(.green), 0);
+    try testing.expectEqual(r0_g1_b2.getCount(.green), 1);
+    try testing.expectEqual(ten_of_each.getCount(.green), 10);
+
+    {
+        var copy = empty;
+        copy.setCount(.red, 6);
+        try testing.expectEqual(copy.getCount(.red), 6);
+    }
+
+    {
+        var copy = r0_g1_b2;
+        copy.addSetAssertSafe(ten_of_each);
+        try testing.expectEqual(copy.getCount(.red), 10);
+        try testing.expectEqual(copy.getCount(.green), 11);
+        try testing.expectEqual(copy.getCount(.blue), 12);
+    }
+
+    {
+        var copy = r0_g1_b2;
+        try copy.addSet(ten_of_each);
+        try testing.expectEqual(copy.getCount(.red), 10);
+        try testing.expectEqual(copy.getCount(.green), 11);
+        try testing.expectEqual(copy.getCount(.blue), 12);
+
+        const full = EnumMultiset(Ball).initWithCount(std.math.maxInt(usize));
+        try testing.expectError(error.Overflow, copy.addSet(full));
+    }
+
+    {
+        var copy = ten_of_each;
+        copy.removeSet(r0_g1_b2);
+        try testing.expectEqual(copy.getCount(.red), 10);
+        try testing.expectEqual(copy.getCount(.green), 9);
+        try testing.expectEqual(copy.getCount(.blue), 8);
+
+        copy.removeSet(ten_of_each);
+        try testing.expectEqual(copy.getCount(.red), 0);
+        try testing.expectEqual(copy.getCount(.green), 0);
+        try testing.expectEqual(copy.getCount(.blue), 0);
+    }
+
+    try testing.expect(empty.eql(empty));
+    try testing.expect(r0_g1_b2.eql(r0_g1_b2));
+    try testing.expect(ten_of_each.eql(ten_of_each));
+    try testing.expect(!empty.eql(r0_g1_b2));
+    try testing.expect(!r0_g1_b2.eql(ten_of_each));
+    try testing.expect(!ten_of_each.eql(empty));
+
+    {
+        const result = r0_g1_b2.plusAssertSafe(ten_of_each);
+        try testing.expectEqual(result.getCount(.red), 10);
+        try testing.expectEqual(result.getCount(.green), 11);
+        try testing.expectEqual(result.getCount(.blue), 12);
+    }
+
+    {
+        const result = try r0_g1_b2.plus(ten_of_each);
+        try testing.expectEqual(result.getCount(.red), 10);
+        try testing.expectEqual(result.getCount(.green), 11);
+        try testing.expectEqual(result.getCount(.blue), 12);
+
+        const full = EnumMultiset(Ball).initWithCount(std.math.maxInt(usize));
+        try testing.expectError(error.Overflow, result.plus(full));
+    }
+
+    {
+        const result = ten_of_each.minus(r0_g1_b2);
+        try testing.expectEqual(result.getCount(.red), 10);
+        try testing.expectEqual(result.getCount(.green), 9);
+        try testing.expectEqual(result.getCount(.blue), 8);
+    }
+
+    {
+        const result = ten_of_each.minus(r0_g1_b2).minus(ten_of_each);
+        try testing.expectEqual(result.getCount(.red), 0);
+        try testing.expectEqual(result.getCount(.green), 0);
+        try testing.expectEqual(result.getCount(.blue), 0);
+    }
+
+    {
+        var copy = empty;
+        var it = copy.iterator();
+        var entry = it.next().?;
+        try testing.expectEqual(entry.key, .red);
+        try testing.expectEqual(entry.value.*, 0);
+        entry = it.next().?;
+        try testing.expectEqual(entry.key, .green);
+        try testing.expectEqual(entry.value.*, 0);
+        entry = it.next().?;
+        try testing.expectEqual(entry.key, .blue);
+        try testing.expectEqual(entry.value.*, 0);
+        try testing.expectEqual(it.next(), null);
+    }
+
+    {
+        var copy = r0_g1_b2;
+        var it = copy.iterator();
+        var entry = it.next().?;
+        try testing.expectEqual(entry.key, .red);
+        try testing.expectEqual(entry.value.*, 0);
+        entry = it.next().?;
+        try testing.expectEqual(entry.key, .green);
+        try testing.expectEqual(entry.value.*, 1);
+        entry = it.next().?;
+        try testing.expectEqual(entry.key, .blue);
+        try testing.expectEqual(entry.value.*, 2);
+        try testing.expectEqual(it.next(), null);
+    }
+}
+
 /// An array keyed by an enum, backed by a dense array.
 /// If the enum is not dense, a mapping will be constructed from
 /// enum values to dense indices.  This type does no dynamic