Commit d8c3c11c6c

Jiacai Liu <dev@liujiacai.net>
2023-01-29 23:00:14
std: add expectEqualDeep (#13995)
1 parent 23b7d28
Changed files (1)
lib
lib/std/testing.zig
@@ -670,6 +670,252 @@ pub fn expectStringEndsWith(actual: []const u8, expected_ends_with: []const u8)
     return error.TestExpectedEndsWith;
 }
 
+/// This function is intended to be used only in tests. When the two values are not
+/// deeply equal, prints diagnostics to stderr to show exactly how they are not equal,
+/// then returns a test failure error.
+/// `actual` is casted to the type of `expected`.
+///
+/// Deeply equal is defined as follows:
+/// Primitive types are deeply equal if they are equal using  `==` operator.
+/// Struct values are deeply equal if their corresponding fields are deeply equal.
+/// Container types(like Array/Slice/Vector) deeply equal when their corresponding elements are deeply equal.
+/// Pointer values are deeply equal if values they point to are deeply equal.
+///
+/// Note: Self-referential structs are not supported (e.g. things like std.SinglyLinkedList)
+pub fn expectEqualDeep(expected: anytype, actual: @TypeOf(expected)) !void {
+    switch (@typeInfo(@TypeOf(actual))) {
+        .NoReturn,
+        .Opaque,
+        .Frame,
+        .AnyFrame,
+        => @compileError("value of type " ++ @typeName(@TypeOf(actual)) ++ " encountered"),
+
+        .Undefined,
+        .Null,
+        .Void,
+        => return,
+
+        .Type => {
+            if (actual != expected) {
+                std.debug.print("expected type {s}, found type {s}\n", .{ @typeName(expected), @typeName(actual) });
+                return error.TestExpectedEqual;
+            }
+        },
+
+        .Bool,
+        .Int,
+        .Float,
+        .ComptimeFloat,
+        .ComptimeInt,
+        .EnumLiteral,
+        .Enum,
+        .Fn,
+        .ErrorSet,
+        => {
+            if (actual != expected) {
+                std.debug.print("expected {}, found {}\n", .{ expected, actual });
+                return error.TestExpectedEqual;
+            }
+        },
+
+        .Pointer => |pointer| {
+            switch (pointer.size) {
+                // We have no idea what is behind those pointers, so the best we can do is `==` check.
+                .C, .Many => {
+                    if (actual != expected) {
+                        std.debug.print("expected {*}, found {*}\n", .{ expected, actual });
+                        return error.TestExpectedEqual;
+                    }
+                },
+                .One => {
+                    // Length of those pointers are runtime value, so the best we can do is `==` check.
+                    switch (@typeInfo(pointer.child)) {
+                        .Fn, .Opaque => {
+                            if (actual != expected) {
+                                std.debug.print("expected {*}, found {*}\n", .{ expected, actual });
+                                return error.TestExpectedEqual;
+                            }
+                        },
+                        else => try expectEqualDeep(expected.*, actual.*),
+                    }
+                },
+                .Slice => {
+                    if (expected.len != actual.len) {
+                        std.debug.print("Slice len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len });
+                        return error.TestExpectedEqual;
+                    }
+                    var i: usize = 0;
+                    while (i < expected.len) : (i += 1) {
+                        expectEqualDeep(expected[i], actual[i]) catch |e| {
+                            std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
+                                i, expected[i], actual[i],
+                            });
+                            return e;
+                        };
+                    }
+                },
+            }
+        },
+
+        .Array => |_| {
+            if (expected.len != actual.len) {
+                std.debug.print("Array len not the same, expected {d}, found {d}\n", .{ expected.len, actual.len });
+                return error.TestExpectedEqual;
+            }
+            var i: usize = 0;
+            while (i < expected.len) : (i += 1) {
+                expectEqualDeep(expected[i], actual[i]) catch |e| {
+                    std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
+                        i, expected[i], actual[i],
+                    });
+                    return e;
+                };
+            }
+        },
+
+        .Vector => |info| {
+            if (info.len != @typeInfo(@TypeOf(actual)).Vector.len) {
+                std.debug.print("Vector len not the same, expected {d}, found {d}\n", .{ info.len, @typeInfo(@TypeOf(actual)).Vector.len });
+                return error.TestExpectedEqual;
+            }
+            var i: usize = 0;
+            while (i < info.len) : (i += 1) {
+                expectEqualDeep(expected[i], actual[i]) catch |e| {
+                    std.debug.print("index {d} incorrect. expected {any}, found {any}\n", .{
+                        i, expected[i], actual[i],
+                    });
+                    return e;
+                };
+            }
+        },
+
+        .Struct => |structType| {
+            inline for (structType.fields) |field| {
+                expectEqualDeep(@field(expected, field.name), @field(actual, field.name)) catch |e| {
+                    std.debug.print("Field {s} incorrect. expected {any}, found {any}\n", .{ field.name, @field(expected, field.name), @field(actual, field.name) });
+                    return e;
+                };
+            }
+        },
+
+        .Union => |union_info| {
+            if (union_info.tag_type == null) {
+                @compileError("Unable to compare untagged union values");
+            }
+
+            const Tag = std.meta.Tag(@TypeOf(expected));
+
+            const expectedTag = @as(Tag, expected);
+            const actualTag = @as(Tag, actual);
+
+            try expectEqual(expectedTag, actualTag);
+
+            // we only reach this loop if the tags are equal
+            switch (expected) {
+                inline else => |val, tag| {
+                    try expectEqualDeep(val, @field(actual, @tagName(tag)));
+                },
+            }
+        },
+
+        .Optional => {
+            if (expected) |expected_payload| {
+                if (actual) |actual_payload| {
+                    try expectEqualDeep(expected_payload, actual_payload);
+                } else {
+                    std.debug.print("expected {any}, found null\n", .{expected_payload});
+                    return error.TestExpectedEqual;
+                }
+            } else {
+                if (actual) |actual_payload| {
+                    std.debug.print("expected null, found {any}\n", .{actual_payload});
+                    return error.TestExpectedEqual;
+                }
+            }
+        },
+
+        .ErrorUnion => {
+            if (expected) |expected_payload| {
+                if (actual) |actual_payload| {
+                    try expectEqualDeep(expected_payload, actual_payload);
+                } else |actual_err| {
+                    std.debug.print("expected {any}, found {any}\n", .{ expected_payload, actual_err });
+                    return error.TestExpectedEqual;
+                }
+            } else |expected_err| {
+                if (actual) |actual_payload| {
+                    std.debug.print("expected {any}, found {any}\n", .{ expected_err, actual_payload });
+                    return error.TestExpectedEqual;
+                } else |actual_err| {
+                    try expectEqualDeep(expected_err, actual_err);
+                }
+            }
+        },
+    }
+}
+
+test "expectEqualDeep primitive type" {
+    try expectEqualDeep(1, 1);
+    try expectEqualDeep(true, true);
+    try expectEqualDeep(1.5, 1.5);
+    try expectEqualDeep(u8, u8);
+    try expectEqualDeep(error.Bad, error.Bad);
+
+    // optional
+    {
+        const foo: ?u32 = 1;
+        const bar: ?u32 = 1;
+        try expectEqualDeep(foo, bar);
+        try expectEqualDeep(?u32, ?u32);
+    }
+    // function type
+    {
+        const fnType = struct {
+            fn foo() void {
+                unreachable;
+            }
+        }.foo;
+        try expectEqualDeep(fnType, fnType);
+    }
+}
+
+test "expectEqualDeep pointer" {
+    const a = 1;
+    const b = 1;
+    try expectEqualDeep(&a, &b);
+}
+
+test "expectEqualDeep composite type" {
+    try expectEqualDeep("abc", "abc");
+    const s1: []const u8 = "abc";
+    const s2 = "abcd";
+    const s3: []const u8 = s2[0..3];
+    try expectEqualDeep(s1, s3);
+
+    const TestStruct = struct { s: []const u8 };
+    try expectEqualDeep(TestStruct{ .s = "abc" }, TestStruct{ .s = "abc" });
+    try expectEqualDeep([_][]const u8{ "a", "b", "c" }, [_][]const u8{ "a", "b", "c" });
+
+    // vector
+    try expectEqualDeep(@splat(4, @as(u32, 4)), @splat(4, @as(u32, 4)));
+
+    // nested array
+    {
+        const a = [2][2]f32{
+            [_]f32{ 1.0, 0.0 },
+            [_]f32{ 0.0, 1.0 },
+        };
+
+        const b = [2][2]f32{
+            [_]f32{ 1.0, 0.0 },
+            [_]f32{ 0.0, 1.0 },
+        };
+
+        try expectEqualDeep(a, b);
+        try expectEqualDeep(&a, &b);
+    }
+}
+
 fn printIndicatorLine(source: []const u8, indicator_index: usize) void {
     const line_begin_index = if (std.mem.lastIndexOfScalar(u8, source[0..indicator_index], '\n')) |line_begin|
         line_begin + 1