Commit 6d04ab6d5b

Ryan Liptak <squeek502@hotmail.com>
2022-01-13 09:35:50
Add `std.testing.checkAllAllocationFailures`
Adds a function that allows checking for memory leaks (and other problems) by taking advantage of the FailingAllocator and inducing failure at every allocation point within the provided `test_fn` (based on the strategy employed in the Zig parser tests, which can now use this function).
1 parent 91eb1af
Changed files (2)
lib/std/zig/parser_test.zig
@@ -5459,52 +5459,24 @@ fn testParse(source: [:0]const u8, allocator: mem.Allocator, anything_changed: *
     anything_changed.* = !mem.eql(u8, formatted, source);
     return formatted;
 }
-fn testTransform(source: [:0]const u8, expected_source: []const u8) !void {
-    const needed_alloc_count = x: {
-        // Try it once with unlimited memory, make sure it works
-        var fixed_allocator = std.heap.FixedBufferAllocator.init(fixed_buffer_mem[0..]);
-        var failing_allocator = std.testing.FailingAllocator.init(fixed_allocator.allocator(), maxInt(usize));
-        const allocator = failing_allocator.allocator();
-        var anything_changed: bool = undefined;
-        const result_source = try testParse(source, allocator, &anything_changed);
-        try std.testing.expectEqualStrings(expected_source, result_source);
-        const changes_expected = source.ptr != expected_source.ptr;
-        if (anything_changed != changes_expected) {
-            print("std.zig.render returned {} instead of {}\n", .{ anything_changed, changes_expected });
-            return error.TestFailed;
-        }
-        try std.testing.expect(anything_changed == changes_expected);
-        allocator.free(result_source);
-        break :x failing_allocator.index;
-    };
-
-    var fail_index: usize = 0;
-    while (fail_index < needed_alloc_count) : (fail_index += 1) {
-        var fixed_allocator = std.heap.FixedBufferAllocator.init(fixed_buffer_mem[0..]);
-        var failing_allocator = std.testing.FailingAllocator.init(fixed_allocator.allocator(), fail_index);
-        var anything_changed: bool = undefined;
-        if (testParse(source, failing_allocator.allocator(), &anything_changed)) |_| {
-            return error.NondeterministicMemoryUsage;
-        } else |err| switch (err) {
-            error.OutOfMemory => {
-                if (failing_allocator.allocated_bytes != failing_allocator.freed_bytes) {
-                    print(
-                        "\nfail_index: {d}/{d}\nallocated bytes: {d}\nfreed bytes: {d}\nallocations: {d}\ndeallocations: {d}\n",
-                        .{
-                            fail_index,
-                            needed_alloc_count,
-                            failing_allocator.allocated_bytes,
-                            failing_allocator.freed_bytes,
-                            failing_allocator.allocations,
-                            failing_allocator.deallocations,
-                        },
-                    );
-                    return error.MemoryLeakDetected;
-                }
-            },
-            else => return err,
-        }
+fn testTransformImpl(allocator: mem.Allocator, fba: *std.heap.FixedBufferAllocator, source: [:0]const u8, expected_source: []const u8) !void {
+    // reset the fixed buffer allocator each run so that it can be re-used for each
+    // iteration of the failing index
+    fba.reset();
+    var anything_changed: bool = undefined;
+    const result_source = try testParse(source, allocator, &anything_changed);
+    try std.testing.expectEqualStrings(expected_source, result_source);
+    const changes_expected = source.ptr != expected_source.ptr;
+    if (anything_changed != changes_expected) {
+        print("std.zig.render returned {} instead of {}\n", .{ anything_changed, changes_expected });
+        return error.TestFailed;
     }
+    try std.testing.expect(anything_changed == changes_expected);
+    allocator.free(result_source);
+}
+fn testTransform(source: [:0]const u8, expected_source: []const u8) !void {
+    var fixed_allocator = std.heap.FixedBufferAllocator.init(fixed_buffer_mem[0..]);
+    return std.testing.checkAllAllocationFailures(fixed_allocator.allocator(), testTransformImpl, .{ &fixed_allocator, source, expected_source });
 }
 fn testCanonical(source: [:0]const u8) !void {
     return testTransform(source, source);
lib/std/testing.zig
@@ -574,6 +574,150 @@ test {
     try expectEqualStrings("foo", "foo");
 }
 
+/// Exhaustively check that allocation failures within `test_fn` are handled without
+/// introducing memory leaks. If used with the `testing.allocator` as the `backing_allocator`,
+/// it will also be able to detect double frees, etc (when runtime safety is enabled).
+///
+/// The provided `test_fn` must have a `std.mem.Allocator` as its first argument,
+/// and must have a return type of `!void`. Any extra arguments of `test_fn` can
+/// be provided via the `extra_args` tuple.
+///
+/// Any relevant state shared between runs of `test_fn` *must* be reset within `test_fn`.
+///
+/// Expects that the `test_fn` has a deterministic number of memory allocations
+/// (an error will be returned if non-deterministic allocations are detected).
+///
+/// The strategy employed is to:
+/// - Run the test function once to get the total number of allocations.
+/// - Then, iterate and run the function X more times, incrementing
+///   the failing index each iteration (where X is the total number of
+///   allocations determined previously)
+///
+/// ---
+///
+/// Here's an example of using a simple test case that will cause a leak when the
+/// allocation of `bar` fails (but will pass normally):
+///
+/// ```zig
+/// test {
+///     const length: usize = 10;
+///     const allocator = std.testing.allocator;
+///     var foo = try allocator.alloc(u8, length);
+///     var bar = try allocator.alloc(u8, length);
+///
+///     allocator.free(foo);
+///     allocator.free(bar);
+/// }
+/// ```
+///
+/// The test case can be converted to something that this function can use by
+/// doing:
+///
+/// ```zig
+/// fn testImpl(allocator: std.mem.Allocator, length: usize) !void {
+///     var foo = try allocator.alloc(u8, length);
+///     var bar = try allocator.alloc(u8, length);
+///
+///     allocator.free(foo);
+///     allocator.free(bar);
+/// }
+///
+/// test {
+///     const length: usize = 10;
+///     const allocator = std.testing.allocator;
+///     try std.testing.checkAllAllocationFailures(allocator, testImpl, .{length});
+/// }
+/// ```
+///
+/// Running this test will show that `foo` is leaked when the allocation of
+/// `bar` fails. The simplest fix, in this case, would be to use defer like so:
+///
+/// ```zig
+/// fn testImpl(allocator: std.mem.Allocator, length: usize) !void {
+///     var foo = try allocator.alloc(u8, length);
+///     defer allocator.free(foo);
+///     var bar = try allocator.alloc(u8, length);
+///     defer allocator.free(bar);
+/// }
+/// ```
+pub fn checkAllAllocationFailures(backing_allocator: std.mem.Allocator, comptime test_fn: anytype, extra_args: anytype) !void {
+    switch (@typeInfo(@typeInfo(@TypeOf(test_fn)).Fn.return_type.?)) {
+        .ErrorUnion => |info| {
+            if (info.payload != void) {
+                @compileError("Return type must be !void");
+            }
+        },
+        else => @compileError("Return type must be !void"),
+    }
+    if (@typeInfo(@TypeOf(extra_args)) != .Struct) {
+        @compileError("Expected tuple or struct argument, found " ++ @typeName(@TypeOf(extra_args)));
+    }
+
+    const ArgsTuple = std.meta.ArgsTuple(@TypeOf(test_fn));
+    const fn_args_fields = @typeInfo(ArgsTuple).Struct.fields;
+    if (fn_args_fields.len == 0 or fn_args_fields[0].field_type != std.mem.Allocator) {
+        @compileError("The provided function must have an " ++ @typeName(std.mem.Allocator) ++ " as its first argument");
+    }
+    const expected_args_tuple_len = fn_args_fields.len - 1;
+    if (extra_args.len != expected_args_tuple_len) {
+        @compileError("The provided function expects " ++ (comptime std.fmt.comptimePrint("{d}", .{expected_args_tuple_len})) ++ " extra arguments, but the provided tuple contains " ++ (comptime std.fmt.comptimePrint("{d}", .{extra_args.len})));
+    }
+
+    // Setup the tuple that will actually be used with @call (we'll need to insert
+    // the failing allocator in field @"0" before each @call)
+    var args: ArgsTuple = undefined;
+    inline for (@typeInfo(@TypeOf(extra_args)).Struct.fields) |field, i| {
+        const expected_type = fn_args_fields[i + 1].field_type;
+        if (expected_type != field.field_type) {
+            @compileError("Unexpected type for extra argument at index " ++ (comptime std.fmt.comptimePrint("{d}", .{i})) ++ ": expected " ++ @typeName(expected_type) ++ ", found " ++ @typeName(field.field_type));
+        }
+        const arg_i_str = comptime str: {
+            var str_buf: [100]u8 = undefined;
+            const args_i = i + 1;
+            const str_len = std.fmt.formatIntBuf(&str_buf, args_i, 10, .lower, .{});
+            break :str str_buf[0..str_len];
+        };
+        @field(args, arg_i_str) = @field(extra_args, field.name);
+    }
+
+    // Try it once with unlimited memory, make sure it works
+    const needed_alloc_count = x: {
+        var failing_allocator_inst = std.testing.FailingAllocator.init(backing_allocator, std.math.maxInt(usize));
+        args.@"0" = failing_allocator_inst.allocator();
+
+        try @call(.{}, test_fn, args);
+        break :x failing_allocator_inst.index;
+    };
+
+    var fail_index: usize = 0;
+    while (fail_index < needed_alloc_count) : (fail_index += 1) {
+        var failing_allocator_inst = std.testing.FailingAllocator.init(backing_allocator, fail_index);
+        args.@"0" = failing_allocator_inst.allocator();
+
+        if (@call(.{}, test_fn, args)) |_| {
+            return error.NondeterministicMemoryUsage;
+        } else |err| switch (err) {
+            error.OutOfMemory => {
+                if (failing_allocator_inst.allocated_bytes != failing_allocator_inst.freed_bytes) {
+                    print(
+                        "\nfail_index: {d}/{d}\nallocated bytes: {d}\nfreed bytes: {d}\nallocations: {d}\ndeallocations: {d}\n",
+                        .{
+                            fail_index,
+                            needed_alloc_count,
+                            failing_allocator_inst.allocated_bytes,
+                            failing_allocator_inst.freed_bytes,
+                            failing_allocator_inst.allocations,
+                            failing_allocator_inst.deallocations,
+                        },
+                    );
+                    return error.MemoryLeakDetected;
+                }
+            },
+            else => return err,
+        }
+    }
+}
+
 /// Given a type, reference all the declarations inside, so that the semantic analyzer sees them.
 pub fn refAllDecls(comptime T: type) void {
     if (!builtin.is_test) return;