Commit 4e212f1650

Eric Joldasov <bratishkaerik@getgoogleoff.me>
2023-11-17 18:36:44
std.enums: allow non-exhaustive enums in EnumIndexer and make `count` comptime_int instead of usize
Seems like this restriction was actual when Ziglang had extern enums, but now it's not neccessary and can be lifted. It was present since original PR which introduced std.enums, https://www.github.com/ziglang/zig/pull/8171. See also: https://ziggit.dev/t/catching-invalid-enum-value-errors/2206/11 * Make `count` comptime_int instead of usize With previous type, creating EnumIndexer for enum(usize) and enum(isize) would cause compile error since `count` could not store maxInt(usize) + 1. Now it can store it and reflects len field from std.builtin.Type.Array (most common use case of count field inside std.enums functions is creating arrays). Signed-off-by: Eric Joldasov <bratishkaerik@getgoogleoff.me>
1 parent 7b99189
Changed files (1)
lib
lib/std/enums.zig
@@ -1267,7 +1267,7 @@ pub fn IndexedArray(comptime I: type, comptime V: type, comptime Ext: ?fn (type)
 ///     /// The key type which this indexer converts to indices
 ///     pub const Key: type,
 ///     /// The number of indexes in the dense mapping
-///     pub const count: usize,
+///     pub const count: comptime_int,
 ///     /// Converts from a key to an index
 ///     pub fn indexOf(Key) usize;
 ///     /// Converts from an index to a key
@@ -1278,8 +1278,8 @@ pub fn ensureIndexer(comptime T: type) void {
     comptime {
         if (!@hasDecl(T, "Key")) @compileError("Indexer must have decl Key: type.");
         if (@TypeOf(T.Key) != type) @compileError("Indexer.Key must be a type.");
-        if (!@hasDecl(T, "count")) @compileError("Indexer must have decl count: usize.");
-        if (@TypeOf(T.count) != usize) @compileError("Indexer.count must be a usize.");
+        if (!@hasDecl(T, "count")) @compileError("Indexer must have decl count: comptime_int.");
+        if (@TypeOf(T.count) != comptime_int) @compileError("Indexer.count must be a comptime_int.");
         if (!@hasDecl(T, "indexOf")) @compileError("Indexer.indexOf must be a fn (Key) usize.");
         if (@TypeOf(T.indexOf) != fn (T.Key) usize) @compileError("Indexer must have decl indexOf: fn (Key) usize.");
         if (!@hasDecl(T, "keyForIndex")) @compileError("Indexer must have decl keyForIndex: fn (usize) Key.");
@@ -1290,7 +1290,7 @@ pub fn ensureIndexer(comptime T: type) void {
 test "std.enums.ensureIndexer" {
     ensureIndexer(struct {
         pub const Key = u32;
-        pub const count: usize = 8;
+        pub const count: comptime_int = 8;
         pub fn indexOf(k: Key) usize {
             return @as(usize, @intCast(k));
         }
@@ -1302,7 +1302,36 @@ test "std.enums.ensureIndexer" {
 
 pub fn EnumIndexer(comptime E: type) type {
     if (!@typeInfo(E).Enum.is_exhaustive) {
-        @compileError("Cannot create an enum indexer for a non-exhaustive enum.");
+        const BackingInt = @typeInfo(E).Enum.tag_type;
+        if (@bitSizeOf(BackingInt) > @bitSizeOf(usize))
+            @compileError("Cannot create an enum indexer for a given non-exhaustive enum, tag_type is larger than usize.");
+
+        return struct {
+            pub const Key: type = E;
+
+            const backing_int_sign = @typeInfo(BackingInt).Int.signedness;
+            const min_value = std.math.minInt(BackingInt);
+            const max_value = std.math.maxInt(BackingInt);
+
+            const RangeType = std.meta.Int(.unsigned, @bitSizeOf(BackingInt));
+            pub const count: comptime_int = std.math.maxInt(RangeType) + 1;
+
+            pub fn indexOf(e: E) usize {
+                if (backing_int_sign == .unsigned)
+                    return @intFromEnum(e);
+
+                return if (@intFromEnum(e) < 0)
+                    @intCast(@intFromEnum(e) - min_value)
+                else
+                    @as(RangeType, -min_value) + @as(RangeType, @intCast(@intFromEnum(e)));
+            }
+            pub fn keyForIndex(i: usize) E {
+                if (backing_int_sign == .unsigned)
+                    return @enumFromInt(i);
+
+                return @enumFromInt(@as(std.meta.Int(.signed, @bitSizeOf(RangeType) + 1), @intCast(i)) + min_value);
+            }
+        };
     }
 
     const const_fields = std.meta.fields(E);
@@ -1312,7 +1341,7 @@ pub fn EnumIndexer(comptime E: type) type {
     if (fields_len == 0) {
         return struct {
             pub const Key = E;
-            pub const count: usize = 0;
+            pub const count: comptime_int = 0;
             pub fn indexOf(e: E) usize {
                 _ = e;
                 unreachable;
@@ -1343,7 +1372,7 @@ pub fn EnumIndexer(comptime E: type) type {
     if (max - min == fields.len - 1) {
         return struct {
             pub const Key = E;
-            pub const count = fields_len;
+            pub const count: comptime_int = fields_len;
             pub fn indexOf(e: E) usize {
                 return @as(usize, @intCast(@intFromEnum(e) - min));
             }
@@ -1361,7 +1390,7 @@ pub fn EnumIndexer(comptime E: type) type {
 
     return struct {
         pub const Key = E;
-        pub const count = fields_len;
+        pub const count: comptime_int = fields_len;
         pub fn indexOf(e: E) usize {
             for (keys, 0..) |k, i| {
                 if (k == e) return i;
@@ -1374,12 +1403,61 @@ pub fn EnumIndexer(comptime E: type) type {
     };
 }
 
+test "EnumIndexer non-exhaustive" {
+    const backing_ints = [_]type{
+        i1,
+        i2,
+        i3,
+        i4,
+        i8,
+        i16,
+        std.meta.Int(.signed, @bitSizeOf(isize) - 1),
+        isize,
+        u1,
+        u2,
+        u3,
+        u4,
+        u16,
+        std.meta.Int(.unsigned, @bitSizeOf(usize) - 1),
+        usize,
+    };
+    inline for (backing_ints) |BackingInt| {
+        const E = enum(BackingInt) {
+            number_zero_tag = 0,
+            _,
+        };
+        const Indexer = EnumIndexer(E);
+        ensureIndexer(Indexer);
+
+        const min_tag: E = @enumFromInt(std.math.minInt(BackingInt));
+        const max_tag: E = @enumFromInt(std.math.maxInt(BackingInt));
+
+        const RangedType = std.meta.Int(.unsigned, @bitSizeOf(BackingInt));
+        const max_index: comptime_int = std.math.maxInt(RangedType);
+        const number_zero_tag_index: usize = switch (@typeInfo(BackingInt).Int.signedness) {
+            .unsigned => 0,
+            .signed => std.math.divCeil(comptime_int, max_index, 2) catch unreachable,
+        };
+
+        try testing.expectEqual(E, Indexer.Key);
+        try testing.expectEqual(max_index + 1, Indexer.count);
+
+        try testing.expectEqual(@as(usize, 0), Indexer.indexOf(min_tag));
+        try testing.expectEqual(number_zero_tag_index, Indexer.indexOf(E.number_zero_tag));
+        try testing.expectEqual(@as(usize, max_index), Indexer.indexOf(max_tag));
+
+        try testing.expectEqual(min_tag, Indexer.keyForIndex(0));
+        try testing.expectEqual(E.number_zero_tag, Indexer.keyForIndex(number_zero_tag_index));
+        try testing.expectEqual(max_tag, Indexer.keyForIndex(max_index));
+    }
+}
+
 test "std.enums.EnumIndexer dense zeroed" {
     const E = enum(u2) { b = 1, a = 0, c = 2 };
     const Indexer = EnumIndexer(E);
     ensureIndexer(Indexer);
     try testing.expectEqual(E, Indexer.Key);
-    try testing.expectEqual(@as(usize, 3), Indexer.count);
+    try testing.expectEqual(3, Indexer.count);
 
     try testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a));
     try testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b));
@@ -1395,7 +1473,7 @@ test "std.enums.EnumIndexer dense positive" {
     const Indexer = EnumIndexer(E);
     ensureIndexer(Indexer);
     try testing.expectEqual(E, Indexer.Key);
-    try testing.expectEqual(@as(usize, 3), Indexer.count);
+    try testing.expectEqual(3, Indexer.count);
 
     try testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a));
     try testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b));
@@ -1411,7 +1489,7 @@ test "std.enums.EnumIndexer dense negative" {
     const Indexer = EnumIndexer(E);
     ensureIndexer(Indexer);
     try testing.expectEqual(E, Indexer.Key);
-    try testing.expectEqual(@as(usize, 3), Indexer.count);
+    try testing.expectEqual(3, Indexer.count);
 
     try testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a));
     try testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b));
@@ -1427,7 +1505,7 @@ test "std.enums.EnumIndexer sparse" {
     const Indexer = EnumIndexer(E);
     ensureIndexer(Indexer);
     try testing.expectEqual(E, Indexer.Key);
-    try testing.expectEqual(@as(usize, 3), Indexer.count);
+    try testing.expectEqual(3, Indexer.count);
 
     try testing.expectEqual(@as(usize, 0), Indexer.indexOf(.a));
     try testing.expectEqual(@as(usize, 1), Indexer.indexOf(.b));
@@ -1443,5 +1521,5 @@ test "std.enums.EnumIndexer empty" {
     const Indexer = EnumIndexer(E);
     ensureIndexer(Indexer);
     try testing.expectEqual(E, Indexer.Key);
-    try testing.expectEqual(@as(usize, 0), Indexer.count);
+    try testing.expectEqual(0, Indexer.count);
 }