Commit c9613e3d5c

Ryan Liptak <squeek502@hotmail.com>
2023-09-14 11:35:39
ComptimeStringMap: Add version that takes an equality function
This will allow users to construct e.g. a ComptimeStringMap that uses case-insensitive ASCII comparison. Note: the previous ComptimeStringMap API is unchanged (i.e. this does not break any existing code).
1 parent 6998233
Changed files (2)
lib/std/comptime_string_map.zig
@@ -7,7 +7,42 @@ const mem = std.mem;
 ///
 /// `kvs_list` expects a list of `struct { []const u8, V }` (key-value pair) tuples.
 /// You can pass `struct { []const u8 }` (only keys) tuples if `V` is `void`.
-pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
+pub fn ComptimeStringMap(
+    comptime V: type,
+    comptime kvs_list: anytype,
+) type {
+    return ComptimeStringMapWithEql(V, kvs_list, defaultEql);
+}
+
+/// Like `std.mem.eql`, but takes advantage of the fact that the lengths
+/// of `a` and `b` are known to be equal.
+pub fn defaultEql(a: []const u8, b: []const u8) bool {
+    if (a.ptr == b.ptr) return true;
+    for (a, b) |a_elem, b_elem| {
+        if (a_elem != b_elem) return false;
+    }
+    return true;
+}
+
+/// Like `std.ascii.eqlIgnoreCase` but takes advantage of the fact that
+/// the lengths of `a` and `b` are known to be equal.
+pub fn eqlAsciiIgnoreCase(a: []const u8, b: []const u8) bool {
+    if (a.ptr == b.ptr) return true;
+    for (a, b) |a_c, b_c| {
+        if (std.ascii.toLower(a_c) != std.ascii.toLower(b_c)) return false;
+    }
+    return true;
+}
+
+/// ComptimeStringMap, but accepts an equality function (`eql`).
+/// The `eql` function is only called to determine the equality
+/// of equal length strings. Any strings that are not equal length
+/// are never compared using the `eql` function.
+pub fn ComptimeStringMapWithEql(
+    comptime V: type,
+    comptime kvs_list: anytype,
+    comptime eql: fn (a: []const u8, b: []const u8) bool,
+) type {
     const precomputed = comptime blk: {
         @setEvalBranchQuota(1500);
         const KV = struct {
@@ -76,7 +111,7 @@ pub fn ComptimeStringMap(comptime V: type, comptime kvs_list: anytype) type {
                 const kv = precomputed.sorted_kvs[i];
                 if (kv.key.len != str.len)
                     return null;
-                if (mem.eql(u8, kv.key, str))
+                if (eql(kv.key, str))
                     return kv.value;
                 i += 1;
                 if (i >= precomputed.sorted_kvs.len)
@@ -180,3 +215,20 @@ fn testSet(comptime map: anytype) !void {
     try std.testing.expect(!map.has("missing"));
     try std.testing.expect(map.has("these"));
 }
+
+test "ComptimeStringMapWithEql" {
+    const map = ComptimeStringMapWithEql(TestEnum, .{
+        .{ "these", .D },
+        .{ "have", .A },
+        .{ "nothing", .B },
+        .{ "incommon", .C },
+        .{ "samelen", .E },
+    }, eqlAsciiIgnoreCase);
+
+    try testMap(map);
+    try std.testing.expectEqual(TestEnum.A, map.get("HAVE").?);
+    try std.testing.expectEqual(TestEnum.E, map.get("SameLen").?);
+    try std.testing.expect(null == map.get("SameLength"));
+
+    try std.testing.expect(map.has("ThESe"));
+}
lib/std/std.zig
@@ -16,7 +16,8 @@ pub const BufMap = @import("buf_map.zig").BufMap;
 pub const BufSet = @import("buf_set.zig").BufSet;
 /// Deprecated: use `process.Child`.
 pub const ChildProcess = @import("child_process.zig").ChildProcess;
-pub const ComptimeStringMap = @import("comptime_string_map.zig").ComptimeStringMap;
+pub const ComptimeStringMap = comptime_string_map.ComptimeStringMap;
+pub const ComptimeStringMapWithEql = comptime_string_map.ComptimeStringMapWithEql;
 pub const DoublyLinkedList = @import("linked_list.zig").DoublyLinkedList;
 pub const DynLib = @import("dynamic_library.zig").DynLib;
 pub const DynamicBitSet = bit_set.DynamicBitSet;
@@ -74,6 +75,8 @@ pub const coff = @import("coff.zig");
 /// Compression algorithms such as zlib, zstd, etc.
 pub const compress = @import("compress.zig");
 
+pub const comptime_string_map = @import("comptime_string_map.zig");
+
 /// Cryptography.
 pub const crypto = @import("crypto.zig");