Commit 15d5988e69
Changed files (4)
lib
std
lib/std/os/windows/ntdll.zig
@@ -229,6 +229,12 @@ pub extern "ntdll" fn RtlEqualUnicodeString(
CaseInSensitive: BOOLEAN,
) callconv(WINAPI) BOOLEAN;
+pub extern "NtDll" fn RtlUpcaseUnicodeString(
+ DestinationString: *UNICODE_STRING,
+ SourceString: *const UNICODE_STRING,
+ AllocateDestinationString: BOOLEAN,
+) callconv(WINAPI) NTSTATUS;
+
pub extern "ntdll" fn NtLockFile(
FileHandle: HANDLE,
Event: ?HANDLE,
lib/std/buf_map.zig
@@ -9,7 +9,7 @@ const testing = std.testing;
pub const BufMap = struct {
hash_map: BufMapHashMap,
- const BufMapHashMap = StringHashMap([]const u8);
+ pub const BufMapHashMap = StringHashMap([]const u8);
/// Create a BufMap backed by a specific allocator.
/// That allocator will be used for both backing allocations
lib/std/process.zig
@@ -2,7 +2,6 @@ const std = @import("std.zig");
const builtin = @import("builtin");
const os = std.os;
const fs = std.fs;
-const BufMap = std.BufMap;
const mem = std.mem;
const math = std.math;
const Allocator = mem.Allocator;
@@ -53,9 +52,385 @@ test "getCwdAlloc" {
testing.allocator.free(cwd);
}
-/// Caller owns resulting `BufMap`.
-pub fn getEnvMap(allocator: Allocator) !BufMap {
- var result = BufMap.init(allocator);
+/// EnvMap for Windows that handles Unicode-aware case insensitivity for lookups, while also
+/// providing the canonical environment variable names when iterating.
+///
+/// Allows for zero-allocation lookups (even though it needs to do UTF-8 -> UTF-16 -> uppercase
+/// conversions) by allocating a buffer large enough to fit the largest environment variable
+/// name, and using that when doing lookups (i.e. anything that overflows the buffer can be treated
+/// as the environment variable not being found).
+pub const EnvMapWindows = struct {
+ allocator: Allocator,
+ /// Keys are UTF-16le stored as []const u8
+ uppercased_map: std.StringHashMapUnmanaged(EnvValue),
+ /// Buffer for converting to uppercased UTF-16 on key lookups
+ /// Must call `reallocUppercaseBuf` before doing any lookups after a `put` call.
+ uppercase_buf_utf16: []u16 = &[_]u16{},
+ max_name_utf16_length: usize = 0,
+
+ pub const EnvValue = struct {
+ value: []const u8,
+ canonical_name: []const u8,
+ };
+
+ const Self = @This();
+
+ /// Deinitialize with `deinit`.
+ pub fn init(allocator: Allocator) Self {
+ return .{
+ .allocator = allocator,
+ .uppercased_map = std.StringHashMapUnmanaged(EnvValue){},
+ };
+ }
+
+ pub fn deinit(self: *Self) void {
+ var it = self.uppercased_map.iterator();
+ while (it.next()) |entry| {
+ self.allocator.free(entry.key_ptr.*);
+ self.allocator.free(entry.value_ptr.value);
+ self.allocator.free(entry.value_ptr.canonical_name);
+ }
+ self.uppercased_map.deinit(self.allocator);
+ self.allocator.free(self.uppercase_buf_utf16);
+ }
+
+ /// Increases the size of the uppercase buffer if the maximum name size has increased.
+ /// Must be called before any `get` calls after any number of `put` calls.
+ pub fn reallocUppercaseBuf(self: *Self) !void {
+ if (self.max_name_utf16_length > self.uppercase_buf_utf16.len) {
+ self.uppercase_buf_utf16 = try self.allocator.realloc(self.uppercase_buf_utf16, self.max_name_utf16_length);
+ }
+ }
+
+ /// Converts `src` to uppercase using `RtlUpcaseUnicodeString` and puts the result in `dest`.
+ /// Returns the length of the converted UTF-16 string. `dest.len` must be >= `src.len`.
+ ///
+ /// Note: As of now, RtlUpcaseUnicodeString does not seem to handle codepoints above 0x10000
+ /// (i.e. those that require a surrogate pair), so this function will always return a length
+ /// equal to `src.len`. However, if RtlUpcaseUnicodeString is updated to handle codepoints above
+ /// 0x10000, this property would still hold unless there are lowercase <-> uppercase conversions
+ /// that cross over the boundary between codepoints >= 0x10000 and < 0x10000.
+ /// TODO: Is it feasible that Unicode lowercase <-> uppercase conversions could cross that boundary?
+ fn uppercaseName(dest: []u16, src: []const u16) u16 {
+ assert(dest.len >= src.len);
+
+ const dest_bytes = @intCast(u16, dest.len * 2);
+ var dest_string = os.windows.UNICODE_STRING{
+ .Length = dest_bytes,
+ .MaximumLength = dest_bytes,
+ .Buffer = @intToPtr([*]u16, @ptrToInt(dest.ptr)),
+ };
+ const src_bytes = @intCast(u16, src.len * 2);
+ const src_string = os.windows.UNICODE_STRING{
+ .Length = src_bytes,
+ .MaximumLength = src_bytes,
+ .Buffer = @intToPtr([*]u16, @ptrToInt(src.ptr)),
+ };
+ const rc = os.windows.ntdll.RtlUpcaseUnicodeString(&dest_string, &src_string, os.windows.FALSE);
+ switch (rc) {
+ .SUCCESS => return dest_string.Length / 2,
+ else => unreachable, // we are not allocating, so no errors should be possible
+ }
+ }
+
+ /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and
+ /// only allocating the uppercase buf afterwards.
+ pub fn putUtf8(self: *Self, name: []const u8, value: []const u8) !void {
+ const uppercased_len = len: {
+ const name_uppercased_utf16 = uppercased: {
+ var name_utf16_buf = try std.ArrayListAligned(u8, @alignOf(u16)).initCapacity(self.allocator, name.len);
+ errdefer name_utf16_buf.deinit();
+
+ var uppercased_len = try std.unicode.utf8ToUtf16LeWriter(name_utf16_buf.writer(), name);
+ assert(uppercased_len == name_utf16_buf.items.len);
+
+ break :uppercased name_utf16_buf.toOwnedSlice();
+ };
+ errdefer self.allocator.free(name_uppercased_utf16);
+
+ const name_canonical = try self.allocator.dupe(u8, name);
+ errdefer self.allocator.free(name_canonical);
+
+ const value_dupe = try self.allocator.dupe(u8, value);
+ errdefer self.allocator.free(value_dupe);
+
+ const get_or_put = try self.uppercased_map.getOrPut(self.allocator, name_uppercased_utf16);
+ if (get_or_put.found_existing) {
+ // note: this is only safe from UAF because the errdefer that frees this value above
+ // no longer has a possibility of being triggered after this point
+ self.allocator.free(name_uppercased_utf16);
+ self.allocator.free(get_or_put.value_ptr.value);
+ self.allocator.free(get_or_put.value_ptr.canonical_name);
+ } else {
+ get_or_put.key_ptr.* = name_uppercased_utf16;
+ }
+ get_or_put.value_ptr.value = value_dupe;
+ get_or_put.value_ptr.canonical_name = name_canonical;
+
+ break :len name_uppercased_utf16.len;
+ };
+
+ // The buffer for case conversion for key lookups will need to be as big as the largest
+ // key stored in the hash map.
+ self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len);
+ }
+
+ /// Asserts that the name does not already exist in the map.
+ /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and
+ /// only allocating the uppercase buf afterwards.
+ pub fn putUtf16NoClobber(self: *Self, name_utf16: []const u16, value_utf16: []const u16) !void {
+ const uppercased_len = len: {
+ const name_canonical = try std.unicode.utf16leToUtf8Alloc(self.allocator, name_utf16);
+ errdefer self.allocator.free(name_canonical);
+
+ const value = try std.unicode.utf16leToUtf8Alloc(self.allocator, value_utf16);
+ errdefer self.allocator.free(value);
+
+ const name_uppercased_utf16 = try self.allocator.alloc(u16, name_utf16.len);
+ errdefer self.allocator.free(name_uppercased_utf16);
+
+ const uppercased_len = uppercaseName(name_uppercased_utf16, name_utf16);
+ assert(uppercased_len == name_uppercased_utf16.len);
+
+ try self.uppercased_map.putNoClobber(self.allocator, std.mem.sliceAsBytes(name_uppercased_utf16), EnvValue{
+ .value = value,
+ .canonical_name = name_canonical,
+ });
+ break :len name_uppercased_utf16.len;
+ };
+
+ // The buffer for case conversion for key lookups will need to be as big as the largest
+ // key stored in the hash map.
+ self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len);
+ }
+
+ /// Attempts to convert a UTF-8 name into a uppercased UTF-16le name for a lookup. If the
+ /// name cannot be converted, this function will return `null`.
+ fn utf8ToUppercasedUtf16(self: Self, name: []const u8) ?[]u16 {
+ const name_utf16: []u16 = to_utf16: {
+ var utf16_buf_stream = std.io.fixedBufferStream(std.mem.sliceAsBytes(self.uppercase_buf_utf16));
+ _ = std.unicode.utf8ToUtf16LeWriter(utf16_buf_stream.writer(), name) catch |err| switch (err) {
+ // If the buffer isn't large enough, we can treat that as 'env var not found', as we
+ // know anything too large for the buffer can't be found in the map.
+ error.NoSpaceLeft => return null,
+ // Anything with invalid UTF-8 will also not be found in the map, so treat that as
+ // 'env var not found' too
+ error.InvalidUtf8 => return null,
+ };
+ break :to_utf16 std.mem.bytesAsSlice(u16, utf16_buf_stream.getWritten());
+ };
+
+ // uppercase in place
+ const uppercased_len = uppercaseName(name_utf16, name_utf16);
+ assert(uppercased_len == name_utf16.len);
+
+ return name_utf16;
+ }
+
+ /// Returns true if an entry was found and deleted, false otherwise.
+ pub fn remove(self: *Self, name: []const u8) bool {
+ const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return false;
+ const kv = self.uppercased_map.fetchRemove(std.mem.sliceAsBytes(name_utf16)) orelse return false;
+ self.allocator.free(kv.key);
+ self.allocator.free(kv.value.value);
+ self.allocator.free(kv.value.canonical_name);
+ return true;
+ }
+
+ pub fn get(self: Self, name: []const u8) ?EnvValue {
+ const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return null;
+ return self.uppercased_map.get(std.mem.sliceAsBytes(name_utf16));
+ }
+
+ pub fn count(self: Self) EnvMap.Size {
+ return self.uppercased_map.count();
+ }
+
+ pub fn iterator(self: *const Self) Iterator {
+ return .{
+ .env_map = self,
+ .uppercased_map_iterator = self.uppercased_map.iterator(),
+ };
+ }
+
+ pub const Iterator = struct {
+ env_map: *const Self,
+ uppercased_map_iterator: std.StringHashMapUnmanaged(EnvValue).Iterator,
+
+ pub fn next(it: *Iterator) ?EnvMap.Entry {
+ if (it.uppercased_map_iterator.next()) |uppercased_entry| {
+ return EnvMap.Entry{
+ .name = uppercased_entry.value_ptr.canonical_name,
+ .value = uppercased_entry.value_ptr.value,
+ };
+ } else {
+ return null;
+ }
+ }
+ };
+};
+
+test "EnvMapWindows" {
+ if (builtin.os.tag != .windows) return error.SkipZigTest;
+
+ var env_map = EnvMapWindows.init(testing.allocator);
+ defer env_map.deinit();
+
+ // both put methods
+ try env_map.putUtf16NoClobber(std.unicode.utf8ToUtf16LeStringLiteral("Path"), std.unicode.utf8ToUtf16LeStringLiteral("something"));
+ try env_map.putUtf8("КИРИЛЛИЦА", "something else");
+ try env_map.reallocUppercaseBuf();
+
+ try testing.expectEqual(@as(EnvMap.Size, 2), env_map.count());
+
+ // unicode-aware case-insensitive lookups
+ try testing.expectEqualStrings("something", env_map.get("PATH").?.value);
+ try testing.expectEqualStrings("something else", env_map.get("кириллица").?.value);
+ try testing.expect(env_map.get("missing") == null);
+
+ // canonical names when iterating
+ var it = env_map.iterator();
+ var count: EnvMap.Size = 0;
+ while (it.next()) |entry| {
+ const is_an_expected_name = std.mem.eql(u8, "Path", entry.name) or std.mem.eql(u8, "КИРИЛЛИЦА", entry.name);
+ try testing.expect(is_an_expected_name);
+ count += 1;
+ }
+ try testing.expectEqual(@as(EnvMap.Size, 2), count);
+}
+
+pub const EnvMap = struct {
+ storage: StorageType,
+
+ pub const StorageType = switch (builtin.os.tag) {
+ .windows => EnvMapWindows,
+ else => std.BufMap,
+ };
+
+ /// Matches what BufMap uses for its internal HashMap Size
+ pub const Size = u32;
+
+ const Self = @This();
+
+ /// Deinitialize with `deinit`.
+ pub fn init(allocator: Allocator) Self {
+ return Self{ .storage = StorageType.init(allocator) };
+ }
+
+ pub fn deinit(self: *Self) void {
+ self.storage.deinit();
+ }
+
+ pub fn get(self: Self, name: []const u8) ?[]const u8 {
+ switch (builtin.os.tag) {
+ .windows => {
+ if (self.storage.get(name)) |entry| {
+ return entry.value;
+ } else {
+ return null;
+ }
+ },
+ else => return self.storage.get(name),
+ }
+ }
+
+ pub fn count(self: Self) Size {
+ return self.storage.count();
+ }
+
+ pub fn iterator(self: *const Self) Iterator {
+ return .{ .storage_iterator = self.storage.iterator() };
+ }
+
+ pub fn put(self: *Self, name: []const u8, value: []const u8) !void {
+ switch (builtin.os.tag) {
+ .windows => {
+ try self.storage.putUtf8(name, value);
+ try self.storage.reallocUppercaseBuf();
+ },
+ else => return self.storage.put(name, value),
+ }
+ }
+
+ pub fn remove(self: *Self, name: []const u8) void {
+ _ = self.storage.remove(name);
+ }
+
+ pub const Entry = struct {
+ name: []const u8,
+ value: []const u8,
+ };
+
+ pub const Iterator = struct {
+ storage_iterator: switch (builtin.os.tag) {
+ .windows => EnvMapWindows.Iterator,
+ else => std.BufMap.BufMapHashMap.Iterator,
+ },
+
+ pub fn next(it: *Iterator) ?Entry {
+ switch (builtin.os.tag) {
+ .windows => return it.storage_iterator.next(),
+ else => {
+ if (it.storage_iterator.next()) |entry| {
+ return Entry{
+ .name = entry.key_ptr.*,
+ .value = entry.value_ptr.*,
+ };
+ } else {
+ return null;
+ }
+ },
+ }
+ }
+ };
+};
+
+test "EnvMap" {
+ var env = EnvMap.init(testing.allocator);
+ defer env.deinit();
+
+ try env.put("SOMETHING_NEW", "hello");
+ try testing.expectEqualStrings("hello", env.get("SOMETHING_NEW").?);
+ try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+
+ // overwrite
+ try env.put("SOMETHING_NEW", "something");
+ try testing.expectEqualStrings("something", env.get("SOMETHING_NEW").?);
+ try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+
+ // a new longer name to test the Windows-specific conversion buffer
+ try env.put("SOMETHING_NEW_AND_LONGER", "1");
+ try testing.expectEqualStrings("1", env.get("SOMETHING_NEW_AND_LONGER").?);
+ try testing.expectEqual(@as(EnvMap.Size, 2), env.count());
+
+ // case insensitivity on Windows only
+ if (builtin.os.tag == .windows) {
+ try testing.expectEqualStrings("1", env.get("something_New_aNd_LONGER").?);
+ } else {
+ try testing.expect(null == env.get("something_New_aNd_LONGER"));
+ }
+
+ var it = env.iterator();
+ var count: EnvMap.Size = 0;
+ while (it.next()) |entry| {
+ const is_an_expected_name = std.mem.eql(u8, "SOMETHING_NEW", entry.name) or std.mem.eql(u8, "SOMETHING_NEW_AND_LONGER", entry.name);
+ try testing.expect(is_an_expected_name);
+ count += 1;
+ }
+ try testing.expectEqual(@as(EnvMap.Size, 2), count);
+
+ env.remove("SOMETHING_NEW");
+ try testing.expect(env.get("SOMETHING_NEW") == null);
+
+ try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+}
+
+/// Returns a snapshot of the environment variables of the current process.
+/// Any modifications to the resulting EnvMap will not be not reflected in the environment, and
+/// likewise, any future modifications to the environment will not be reflected in the EnvMap.
+/// Caller owns resulting `EnvMap` and should call its `deinit` fn when done.
+pub fn getEnvMap(allocator: Allocator) !EnvMap {
+ var result = EnvMap.init(allocator);
errdefer result.deinit();
if (builtin.os.tag == .windows) {
@@ -65,23 +440,27 @@ pub fn getEnvMap(allocator: Allocator) !BufMap {
while (ptr[i] != 0) {
const key_start = i;
+ // There are some special environment variables that start with =,
+ // so we need a special case to not treat = as a key/value separator
+ // if it's the first character.
+ // https://devblogs.microsoft.com/oldnewthing/20100506-00/?p=14133
+ if (ptr[key_start] == '=') i += 1;
+
while (ptr[i] != 0 and ptr[i] != '=') : (i += 1) {}
const key_w = ptr[key_start..i];
- const key = try std.unicode.utf16leToUtf8Alloc(allocator, key_w);
- errdefer allocator.free(key);
if (ptr[i] == '=') i += 1;
const value_start = i;
while (ptr[i] != 0) : (i += 1) {}
const value_w = ptr[value_start..i];
- const value = try std.unicode.utf16leToUtf8Alloc(allocator, value_w);
- errdefer allocator.free(value);
- i += 1; // skip over null byte
+ try result.storage.putUtf16NoClobber(key_w, value_w);
- try result.putMove(key, value);
+ i += 1; // skip over null byte
}
+
+ try result.storage.reallocUppercaseBuf();
return result;
} else if (builtin.os.tag == .wasi and !builtin.link_libc) {
var environ_count: usize = undefined;
@@ -140,8 +519,8 @@ pub fn getEnvMap(allocator: Allocator) !BufMap {
}
}
-test "os.getEnvMap" {
- var env = try getEnvMap(std.testing.allocator);
+test "getEnvMap" {
+ var env = try getEnvMap(testing.allocator);
defer env.deinit();
}
lib/std/unicode.zig
@@ -710,6 +710,29 @@ pub fn utf8ToUtf16Le(utf16le: []u16, utf8: []const u8) !usize {
return dest_i;
}
+pub fn utf8ToUtf16LeWriter(writer: anytype, utf8: []const u8) !usize {
+ var src_i: usize = 0;
+ var bytes_written: usize = 0;
+ while (src_i < utf8.len) {
+ const n = utf8ByteSequenceLength(utf8[src_i]) catch return error.InvalidUtf8;
+ const next_src_i = src_i + n;
+ const codepoint = utf8Decode(utf8[src_i..next_src_i]) catch return error.InvalidUtf8;
+ if (codepoint < 0x10000) {
+ const short = @intCast(u16, codepoint);
+ try writer.writeIntLittle(u16, short);
+ bytes_written += 2;
+ } else {
+ const high = @intCast(u16, (codepoint - 0x10000) >> 10) + 0xD800;
+ const low = @intCast(u16, codepoint & 0x3FF) + 0xDC00;
+ try writer.writeIntLittle(u16, high);
+ try writer.writeIntLittle(u16, low);
+ bytes_written += 4;
+ }
+ src_i = next_src_i;
+ }
+ return bytes_written;
+}
+
test "utf8ToUtf16Le" {
var utf16le: [2]u16 = [_]u16{0} ** 2;
{