Commit 15d5988e69

Ryan Liptak <squeek502@hotmail.com>
2022-01-17 05:11:08
Add `process.EnvMap`, a platform-independent environment variable map
EnvMap provides the same API as the previously used BufMap (besides `putMove` and `getPtr`), so usage sites of `getEnvMap` can usually remain unchanged. For non-Windows, EnvMap is a wrapper around BufMap. On Windows, it uses a new EnvMapWindows to handle some Windows-specific behavior: - Lookups use Unicode-aware case insensitivity (but `get` cannot return an error because EnvMapWindows has an internal buffer to use for lookup conversions) - Canonical names are returned when iterating the EnvMap Fixes #10561, closes #4603
1 parent d383b94
Changed files (4)
lib/std/os/windows/ntdll.zig
@@ -229,6 +229,12 @@ pub extern "ntdll" fn RtlEqualUnicodeString(
     CaseInSensitive: BOOLEAN,
 ) callconv(WINAPI) BOOLEAN;
 
+pub extern "NtDll" fn RtlUpcaseUnicodeString(
+    DestinationString: *UNICODE_STRING,
+    SourceString: *const UNICODE_STRING,
+    AllocateDestinationString: BOOLEAN,
+) callconv(WINAPI) NTSTATUS;
+
 pub extern "ntdll" fn NtLockFile(
     FileHandle: HANDLE,
     Event: ?HANDLE,
lib/std/buf_map.zig
@@ -9,7 +9,7 @@ const testing = std.testing;
 pub const BufMap = struct {
     hash_map: BufMapHashMap,
 
-    const BufMapHashMap = StringHashMap([]const u8);
+    pub const BufMapHashMap = StringHashMap([]const u8);
 
     /// Create a BufMap backed by a specific allocator.
     /// That allocator will be used for both backing allocations
lib/std/process.zig
@@ -2,7 +2,6 @@ const std = @import("std.zig");
 const builtin = @import("builtin");
 const os = std.os;
 const fs = std.fs;
-const BufMap = std.BufMap;
 const mem = std.mem;
 const math = std.math;
 const Allocator = mem.Allocator;
@@ -53,9 +52,385 @@ test "getCwdAlloc" {
     testing.allocator.free(cwd);
 }
 
-/// Caller owns resulting `BufMap`.
-pub fn getEnvMap(allocator: Allocator) !BufMap {
-    var result = BufMap.init(allocator);
+/// EnvMap for Windows that handles Unicode-aware case insensitivity for lookups, while also
+/// providing the canonical environment variable names when iterating.
+///
+/// Allows for zero-allocation lookups (even though it needs to do UTF-8 -> UTF-16 -> uppercase
+/// conversions) by allocating a buffer large enough to fit the largest environment variable
+/// name, and using that when doing lookups (i.e. anything that overflows the buffer can be treated
+/// as the environment variable not being found).
+pub const EnvMapWindows = struct {
+    allocator: Allocator,
+    /// Keys are UTF-16le stored as []const u8
+    uppercased_map: std.StringHashMapUnmanaged(EnvValue),
+    /// Buffer for converting to uppercased UTF-16 on key lookups
+    /// Must call `reallocUppercaseBuf` before doing any lookups after a `put` call.
+    uppercase_buf_utf16: []u16 = &[_]u16{},
+    max_name_utf16_length: usize = 0,
+
+    pub const EnvValue = struct {
+        value: []const u8,
+        canonical_name: []const u8,
+    };
+
+    const Self = @This();
+
+    /// Deinitialize with `deinit`.
+    pub fn init(allocator: Allocator) Self {
+        return .{
+            .allocator = allocator,
+            .uppercased_map = std.StringHashMapUnmanaged(EnvValue){},
+        };
+    }
+
+    pub fn deinit(self: *Self) void {
+        var it = self.uppercased_map.iterator();
+        while (it.next()) |entry| {
+            self.allocator.free(entry.key_ptr.*);
+            self.allocator.free(entry.value_ptr.value);
+            self.allocator.free(entry.value_ptr.canonical_name);
+        }
+        self.uppercased_map.deinit(self.allocator);
+        self.allocator.free(self.uppercase_buf_utf16);
+    }
+
+    /// Increases the size of the uppercase buffer if the maximum name size has increased.
+    /// Must be called before any `get` calls after any number of `put` calls.
+    pub fn reallocUppercaseBuf(self: *Self) !void {
+        if (self.max_name_utf16_length > self.uppercase_buf_utf16.len) {
+            self.uppercase_buf_utf16 = try self.allocator.realloc(self.uppercase_buf_utf16, self.max_name_utf16_length);
+        }
+    }
+
+    /// Converts `src` to uppercase using `RtlUpcaseUnicodeString` and puts the result in `dest`.
+    /// Returns the length of the converted UTF-16 string. `dest.len` must be >= `src.len`.
+    ///
+    /// Note: As of now, RtlUpcaseUnicodeString does not seem to handle codepoints above 0x10000
+    /// (i.e. those that require a surrogate pair), so this function will always return a length
+    /// equal to `src.len`. However, if RtlUpcaseUnicodeString is updated to handle codepoints above
+    /// 0x10000, this property would still hold unless there are lowercase <-> uppercase conversions
+    /// that cross over the boundary between codepoints >= 0x10000 and < 0x10000.
+    /// TODO: Is it feasible that Unicode lowercase <-> uppercase conversions could cross that boundary?
+    fn uppercaseName(dest: []u16, src: []const u16) u16 {
+        assert(dest.len >= src.len);
+
+        const dest_bytes = @intCast(u16, dest.len * 2);
+        var dest_string = os.windows.UNICODE_STRING{
+            .Length = dest_bytes,
+            .MaximumLength = dest_bytes,
+            .Buffer = @intToPtr([*]u16, @ptrToInt(dest.ptr)),
+        };
+        const src_bytes = @intCast(u16, src.len * 2);
+        const src_string = os.windows.UNICODE_STRING{
+            .Length = src_bytes,
+            .MaximumLength = src_bytes,
+            .Buffer = @intToPtr([*]u16, @ptrToInt(src.ptr)),
+        };
+        const rc = os.windows.ntdll.RtlUpcaseUnicodeString(&dest_string, &src_string, os.windows.FALSE);
+        switch (rc) {
+            .SUCCESS => return dest_string.Length / 2,
+            else => unreachable, // we are not allocating, so no errors should be possible
+        }
+    }
+
+    /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and
+    /// only allocating the uppercase buf afterwards.
+    pub fn putUtf8(self: *Self, name: []const u8, value: []const u8) !void {
+        const uppercased_len = len: {
+            const name_uppercased_utf16 = uppercased: {
+                var name_utf16_buf = try std.ArrayListAligned(u8, @alignOf(u16)).initCapacity(self.allocator, name.len);
+                errdefer name_utf16_buf.deinit();
+
+                var uppercased_len = try std.unicode.utf8ToUtf16LeWriter(name_utf16_buf.writer(), name);
+                assert(uppercased_len == name_utf16_buf.items.len);
+
+                break :uppercased name_utf16_buf.toOwnedSlice();
+            };
+            errdefer self.allocator.free(name_uppercased_utf16);
+
+            const name_canonical = try self.allocator.dupe(u8, name);
+            errdefer self.allocator.free(name_canonical);
+
+            const value_dupe = try self.allocator.dupe(u8, value);
+            errdefer self.allocator.free(value_dupe);
+
+            const get_or_put = try self.uppercased_map.getOrPut(self.allocator, name_uppercased_utf16);
+            if (get_or_put.found_existing) {
+                // note: this is only safe from UAF because the errdefer that frees this value above
+                // no longer has a possibility of being triggered after this point
+                self.allocator.free(name_uppercased_utf16);
+                self.allocator.free(get_or_put.value_ptr.value);
+                self.allocator.free(get_or_put.value_ptr.canonical_name);
+            } else {
+                get_or_put.key_ptr.* = name_uppercased_utf16;
+            }
+            get_or_put.value_ptr.value = value_dupe;
+            get_or_put.value_ptr.canonical_name = name_canonical;
+
+            break :len name_uppercased_utf16.len;
+        };
+
+        // The buffer for case conversion for key lookups will need to be as big as the largest
+        // key stored in the hash map.
+        self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len);
+    }
+
+    /// Asserts that the name does not already exist in the map.
+    /// Note: Does not realloc the uppercase buf to allow for calling put for many variables and
+    /// only allocating the uppercase buf afterwards.
+    pub fn putUtf16NoClobber(self: *Self, name_utf16: []const u16, value_utf16: []const u16) !void {
+        const uppercased_len = len: {
+            const name_canonical = try std.unicode.utf16leToUtf8Alloc(self.allocator, name_utf16);
+            errdefer self.allocator.free(name_canonical);
+
+            const value = try std.unicode.utf16leToUtf8Alloc(self.allocator, value_utf16);
+            errdefer self.allocator.free(value);
+
+            const name_uppercased_utf16 = try self.allocator.alloc(u16, name_utf16.len);
+            errdefer self.allocator.free(name_uppercased_utf16);
+
+            const uppercased_len = uppercaseName(name_uppercased_utf16, name_utf16);
+            assert(uppercased_len == name_uppercased_utf16.len);
+
+            try self.uppercased_map.putNoClobber(self.allocator, std.mem.sliceAsBytes(name_uppercased_utf16), EnvValue{
+                .value = value,
+                .canonical_name = name_canonical,
+            });
+            break :len name_uppercased_utf16.len;
+        };
+
+        // The buffer for case conversion for key lookups will need to be as big as the largest
+        // key stored in the hash map.
+        self.max_name_utf16_length = @maximum(self.max_name_utf16_length, uppercased_len);
+    }
+
+    /// Attempts to convert a UTF-8 name into a uppercased UTF-16le name for a lookup. If the
+    /// name cannot be converted, this function will return `null`.
+    fn utf8ToUppercasedUtf16(self: Self, name: []const u8) ?[]u16 {
+        const name_utf16: []u16 = to_utf16: {
+            var utf16_buf_stream = std.io.fixedBufferStream(std.mem.sliceAsBytes(self.uppercase_buf_utf16));
+            _ = std.unicode.utf8ToUtf16LeWriter(utf16_buf_stream.writer(), name) catch |err| switch (err) {
+                // If the buffer isn't large enough, we can treat that as 'env var not found', as we
+                // know anything too large for the buffer can't be found in the map.
+                error.NoSpaceLeft => return null,
+                // Anything with invalid UTF-8 will also not be found in the map, so treat that as
+                // 'env var not found' too
+                error.InvalidUtf8 => return null,
+            };
+            break :to_utf16 std.mem.bytesAsSlice(u16, utf16_buf_stream.getWritten());
+        };
+
+        // uppercase in place
+        const uppercased_len = uppercaseName(name_utf16, name_utf16);
+        assert(uppercased_len == name_utf16.len);
+
+        return name_utf16;
+    }
+
+    /// Returns true if an entry was found and deleted, false otherwise.
+    pub fn remove(self: *Self, name: []const u8) bool {
+        const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return false;
+        const kv = self.uppercased_map.fetchRemove(std.mem.sliceAsBytes(name_utf16)) orelse return false;
+        self.allocator.free(kv.key);
+        self.allocator.free(kv.value.value);
+        self.allocator.free(kv.value.canonical_name);
+        return true;
+    }
+
+    pub fn get(self: Self, name: []const u8) ?EnvValue {
+        const name_utf16 = self.utf8ToUppercasedUtf16(name) orelse return null;
+        return self.uppercased_map.get(std.mem.sliceAsBytes(name_utf16));
+    }
+
+    pub fn count(self: Self) EnvMap.Size {
+        return self.uppercased_map.count();
+    }
+
+    pub fn iterator(self: *const Self) Iterator {
+        return .{
+            .env_map = self,
+            .uppercased_map_iterator = self.uppercased_map.iterator(),
+        };
+    }
+
+    pub const Iterator = struct {
+        env_map: *const Self,
+        uppercased_map_iterator: std.StringHashMapUnmanaged(EnvValue).Iterator,
+
+        pub fn next(it: *Iterator) ?EnvMap.Entry {
+            if (it.uppercased_map_iterator.next()) |uppercased_entry| {
+                return EnvMap.Entry{
+                    .name = uppercased_entry.value_ptr.canonical_name,
+                    .value = uppercased_entry.value_ptr.value,
+                };
+            } else {
+                return null;
+            }
+        }
+    };
+};
+
+test "EnvMapWindows" {
+    if (builtin.os.tag != .windows) return error.SkipZigTest;
+
+    var env_map = EnvMapWindows.init(testing.allocator);
+    defer env_map.deinit();
+
+    // both put methods
+    try env_map.putUtf16NoClobber(std.unicode.utf8ToUtf16LeStringLiteral("Path"), std.unicode.utf8ToUtf16LeStringLiteral("something"));
+    try env_map.putUtf8("КИРИЛЛИЦА", "something else");
+    try env_map.reallocUppercaseBuf();
+
+    try testing.expectEqual(@as(EnvMap.Size, 2), env_map.count());
+
+    // unicode-aware case-insensitive lookups
+    try testing.expectEqualStrings("something", env_map.get("PATH").?.value);
+    try testing.expectEqualStrings("something else", env_map.get("кириллица").?.value);
+    try testing.expect(env_map.get("missing") == null);
+
+    // canonical names when iterating
+    var it = env_map.iterator();
+    var count: EnvMap.Size = 0;
+    while (it.next()) |entry| {
+        const is_an_expected_name = std.mem.eql(u8, "Path", entry.name) or std.mem.eql(u8, "КИРИЛЛИЦА", entry.name);
+        try testing.expect(is_an_expected_name);
+        count += 1;
+    }
+    try testing.expectEqual(@as(EnvMap.Size, 2), count);
+}
+
+pub const EnvMap = struct {
+    storage: StorageType,
+
+    pub const StorageType = switch (builtin.os.tag) {
+        .windows => EnvMapWindows,
+        else => std.BufMap,
+    };
+
+    /// Matches what BufMap uses for its internal HashMap Size
+    pub const Size = u32;
+
+    const Self = @This();
+
+    /// Deinitialize with `deinit`.
+    pub fn init(allocator: Allocator) Self {
+        return Self{ .storage = StorageType.init(allocator) };
+    }
+
+    pub fn deinit(self: *Self) void {
+        self.storage.deinit();
+    }
+
+    pub fn get(self: Self, name: []const u8) ?[]const u8 {
+        switch (builtin.os.tag) {
+            .windows => {
+                if (self.storage.get(name)) |entry| {
+                    return entry.value;
+                } else {
+                    return null;
+                }
+            },
+            else => return self.storage.get(name),
+        }
+    }
+
+    pub fn count(self: Self) Size {
+        return self.storage.count();
+    }
+
+    pub fn iterator(self: *const Self) Iterator {
+        return .{ .storage_iterator = self.storage.iterator() };
+    }
+
+    pub fn put(self: *Self, name: []const u8, value: []const u8) !void {
+        switch (builtin.os.tag) {
+            .windows => {
+                try self.storage.putUtf8(name, value);
+                try self.storage.reallocUppercaseBuf();
+            },
+            else => return self.storage.put(name, value),
+        }
+    }
+
+    pub fn remove(self: *Self, name: []const u8) void {
+        _ = self.storage.remove(name);
+    }
+
+    pub const Entry = struct {
+        name: []const u8,
+        value: []const u8,
+    };
+
+    pub const Iterator = struct {
+        storage_iterator: switch (builtin.os.tag) {
+            .windows => EnvMapWindows.Iterator,
+            else => std.BufMap.BufMapHashMap.Iterator,
+        },
+
+        pub fn next(it: *Iterator) ?Entry {
+            switch (builtin.os.tag) {
+                .windows => return it.storage_iterator.next(),
+                else => {
+                    if (it.storage_iterator.next()) |entry| {
+                        return Entry{
+                            .name = entry.key_ptr.*,
+                            .value = entry.value_ptr.*,
+                        };
+                    } else {
+                        return null;
+                    }
+                },
+            }
+        }
+    };
+};
+
+test "EnvMap" {
+    var env = EnvMap.init(testing.allocator);
+    defer env.deinit();
+
+    try env.put("SOMETHING_NEW", "hello");
+    try testing.expectEqualStrings("hello", env.get("SOMETHING_NEW").?);
+    try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+
+    // overwrite
+    try env.put("SOMETHING_NEW", "something");
+    try testing.expectEqualStrings("something", env.get("SOMETHING_NEW").?);
+    try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+
+    // a new longer name to test the Windows-specific conversion buffer
+    try env.put("SOMETHING_NEW_AND_LONGER", "1");
+    try testing.expectEqualStrings("1", env.get("SOMETHING_NEW_AND_LONGER").?);
+    try testing.expectEqual(@as(EnvMap.Size, 2), env.count());
+
+    // case insensitivity on Windows only
+    if (builtin.os.tag == .windows) {
+        try testing.expectEqualStrings("1", env.get("something_New_aNd_LONGER").?);
+    } else {
+        try testing.expect(null == env.get("something_New_aNd_LONGER"));
+    }
+
+    var it = env.iterator();
+    var count: EnvMap.Size = 0;
+    while (it.next()) |entry| {
+        const is_an_expected_name = std.mem.eql(u8, "SOMETHING_NEW", entry.name) or std.mem.eql(u8, "SOMETHING_NEW_AND_LONGER", entry.name);
+        try testing.expect(is_an_expected_name);
+        count += 1;
+    }
+    try testing.expectEqual(@as(EnvMap.Size, 2), count);
+
+    env.remove("SOMETHING_NEW");
+    try testing.expect(env.get("SOMETHING_NEW") == null);
+
+    try testing.expectEqual(@as(EnvMap.Size, 1), env.count());
+}
+
+/// Returns a snapshot of the environment variables of the current process.
+/// Any modifications to the resulting EnvMap will not be not reflected in the environment, and
+/// likewise, any future modifications to the environment will not be reflected in the EnvMap.
+/// Caller owns resulting `EnvMap` and should call its `deinit` fn when done.
+pub fn getEnvMap(allocator: Allocator) !EnvMap {
+    var result = EnvMap.init(allocator);
     errdefer result.deinit();
 
     if (builtin.os.tag == .windows) {
@@ -65,23 +440,27 @@ pub fn getEnvMap(allocator: Allocator) !BufMap {
         while (ptr[i] != 0) {
             const key_start = i;
 
+            // There are some special environment variables that start with =,
+            // so we need a special case to not treat = as a key/value separator
+            // if it's the first character.
+            // https://devblogs.microsoft.com/oldnewthing/20100506-00/?p=14133
+            if (ptr[key_start] == '=') i += 1;
+
             while (ptr[i] != 0 and ptr[i] != '=') : (i += 1) {}
             const key_w = ptr[key_start..i];
-            const key = try std.unicode.utf16leToUtf8Alloc(allocator, key_w);
-            errdefer allocator.free(key);
 
             if (ptr[i] == '=') i += 1;
 
             const value_start = i;
             while (ptr[i] != 0) : (i += 1) {}
             const value_w = ptr[value_start..i];
-            const value = try std.unicode.utf16leToUtf8Alloc(allocator, value_w);
-            errdefer allocator.free(value);
 
-            i += 1; // skip over null byte
+            try result.storage.putUtf16NoClobber(key_w, value_w);
 
-            try result.putMove(key, value);
+            i += 1; // skip over null byte
         }
+
+        try result.storage.reallocUppercaseBuf();
         return result;
     } else if (builtin.os.tag == .wasi and !builtin.link_libc) {
         var environ_count: usize = undefined;
@@ -140,8 +519,8 @@ pub fn getEnvMap(allocator: Allocator) !BufMap {
     }
 }
 
-test "os.getEnvMap" {
-    var env = try getEnvMap(std.testing.allocator);
+test "getEnvMap" {
+    var env = try getEnvMap(testing.allocator);
     defer env.deinit();
 }
 
lib/std/unicode.zig
@@ -710,6 +710,29 @@ pub fn utf8ToUtf16Le(utf16le: []u16, utf8: []const u8) !usize {
     return dest_i;
 }
 
+pub fn utf8ToUtf16LeWriter(writer: anytype, utf8: []const u8) !usize {
+    var src_i: usize = 0;
+    var bytes_written: usize = 0;
+    while (src_i < utf8.len) {
+        const n = utf8ByteSequenceLength(utf8[src_i]) catch return error.InvalidUtf8;
+        const next_src_i = src_i + n;
+        const codepoint = utf8Decode(utf8[src_i..next_src_i]) catch return error.InvalidUtf8;
+        if (codepoint < 0x10000) {
+            const short = @intCast(u16, codepoint);
+            try writer.writeIntLittle(u16, short);
+            bytes_written += 2;
+        } else {
+            const high = @intCast(u16, (codepoint - 0x10000) >> 10) + 0xD800;
+            const low = @intCast(u16, codepoint & 0x3FF) + 0xDC00;
+            try writer.writeIntLittle(u16, high);
+            try writer.writeIntLittle(u16, low);
+            bytes_written += 4;
+        }
+        src_i = next_src_i;
+    }
+    return bytes_written;
+}
+
 test "utf8ToUtf16Le" {
     var utf16le: [2]u16 = [_]u16{0} ** 2;
     {